har*_*ash 18 python machine-learning neural-network keras
是否可以获取使用加载的文件名flow_from_directory
?我有 :
datagen = ImageDataGenerator(
rotation_range=3,
# featurewise_std_normalization=True,
fill_mode='nearest',
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True
)
train_generator = datagen.flow_from_directory(
path+'/train',
target_size=(224, 224),
batch_size=batch_size,)
Run Code Online (Sandbox Code Playgroud)
我有一个自定义生成器用于我的多输出模型,如:
a = np.arange(8).reshape(2, 4)
# print(a)
print(train_generator.filenames)
def generate():
while 1:
x,y = train_generator.next()
yield [x] ,[a,y]
Run Code Online (Sandbox Code Playgroud)
节点,此刻我正在a
为实际训练生成随机数,我希望加载一个json
包含我的图像的边界框坐标的文件.为此,我需要获取使用train_generator.next()
方法生成的文件名.在我有了之后,我可以加载文件,解析json
并传递它而不是a
.x
变量的排序和我得到的文件名列表也是必须的.
Pic*_*ard 23
是的,至少在版本2.0.4(不知道早期版本)是可能的.
实例ImageDataGenerator().flow_from_directory(...)
具有一个属性,filenames
其中包含生成器生成它们的顺序中的所有文件的列表以及属性batch_index
.所以你可以这样做:
datagen = ImageDataGenerator()
gen = datagen.flow_from_directory(...)
Run Code Online (Sandbox Code Playgroud)
在生成器的每次迭代中,您都可以获得相应的文件名,如下所示:
for i in gen:
idx = (gen.batch_index - 1) * gen.batch_size
print(gen.filenames[idx : idx + gen.batch_size])
Run Code Online (Sandbox Code Playgroud)
这将为您提供当前批次中图像的文件名.
您可以创建一个非常小的子类,image, file_path
通过继承来返回元组DirectoryIterator
:
import numpy as np
from keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
class ImageWithNames(DirectoryIterator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.filenames_np = np.array(self.filepaths)
self.class_mode = None # so that we only get the images back
def _get_batches_of_transformed_samples(self, index_array):
return (super()._get_batches_of_transformed_samples(index_array),
self.filenames_np[index_array])
Run Code Online (Sandbox Code Playgroud)
在 init 中,我添加了一个 numpy 版本的属性,self.filepaths
以便我们可以轻松地对该数组进行索引以获取每个批次生成的路径。
对基类的唯一其他更改是返回一个元组,即图像批处理super()._get_batches_of_transformed_samples(index_array)
和文件路径self.filenames_np[index_array]
。
有了这个,你可以像这样制作你的生成器:
imagegen = ImageDataGenerator()
datagen = ImageWithNames('/data/path', imagegen, target_size=(224,224))
Run Code Online (Sandbox Code Playgroud)
然后检查
next(datagen)
Run Code Online (Sandbox Code Playgroud)
至少在2.2.4版本中,你可以这样做
datagen = ImageDataGenerator()
gen = datagen.flow_from_directory(...)
for file in gen.filenames:
print(file)
Run Code Online (Sandbox Code Playgroud)
或获取文件路径
for filepath in gen.filepaths:
print(filepath)
Run Code Online (Sandbox Code Playgroud)