语境
为了fit_generator()在 Keras 中使用,我使用了一个像这样的伪代码-one的生成器函数:
def generator(data: np.array) -> (np.array, np.array):
"""Simple generator yielding some samples and targets"""
while True:
for batch in range(number_of_batches):
yield data[batch * length_sequence], data[(batch + 1) * length_sequence]
Run Code Online (Sandbox Code Playgroud)
在 Keras 的fit_generator()函数中,我想使用workers=4并且use_multiprocessing=True- 因此,我需要一个线程安全生成器。
在像here或here或Keras docs这样的stackoverflow的答案中,我读到了关于创建一个从Keras.utils.Sequence()这样继承的类:
class generatorClass(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x …Run Code Online (Sandbox Code Playgroud)