如何编写一个有状态的Keras fit_generator的生成器?

00_*_*_00 3 python machine-learning generator neural-network keras

我正在尝试将大型数据集提供给keras模型。数据集不适合内存。当前存储为一系列hd5f文件

我想使用训练我的模型

model.fit_generator(my_gen, steps_per_epoch=30, epochs=10, verbose=1)
Run Code Online (Sandbox Code Playgroud)

但是,在我可以在线找到的所有示例中,这些示例my_gen仅用于对已加载的数据集执行数据扩充。例如

def generator(features, labels, batch_size):

 # Create empty arrays to contain batch of features and labels#

 batch_features = np.zeros((batch_size, 64, 64, 3))
 batch_labels = np.zeros((batch_size,1))

 while True:
   for i in range(batch_size):
     # choose random index in features
     index= random.choice(len(features),1)
     batch_features[i] = some_processing(features[index])
     batch_labels[i] = labels[index]
   yield batch_features, batch_labels
Run Code Online (Sandbox Code Playgroud)

就我而言,它必须像

def generator(features, labels, batch_size):    
 while True:
   for i in range(batch_size):
     # choose random index in features
     index= # SELECT THE NEXT FILE
     batch_features[i] = some_processing(features[files[index]])
     batch_labels[i] = labels[file[index]]
   yield batch_features, batch_labels
Run Code Online (Sandbox Code Playgroud)

如何跟踪上一批中已读取的文件?

den*_*s-w 6

来自keras文档

生成器:生成器或Sequence(keras.utils.Sequence)对象的实例,以便在使用多重处理时避免重复数据。[...]

这意味着您可以编写从keras.utils.sequence继承的类

class ProductSequence(keras.utils.Sequence):
    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass
Run Code Online (Sandbox Code Playgroud)

__init__ist来初始化课程。 __len__应该返回每个时期的批次数。Keras将使用它来知道可以将哪个索引传递给__getitem____getitem__然后将根据索引返回批处理数据。一个简单的例子可以在这里找到

使用这种方法,您可以简单地拥有一个内部类对象,在其中保存已读取的文件。