相关疑难解决方法(0)

关于 fit_generator() / fit() 和线程安全

语境

为了fit_generator()在 Keras 中使用,我使用了一个像这样的伪代码-one的生成器函数:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]
Run Code Online (Sandbox Code Playgroud)

在 Keras 的fit_generator()函数中,我想使用workers=4并且use_multiprocessing=True- 因此,我需要一个线程安全生成器。

在像herehere或Keras docs这样的stackoverflow的答案中,我读到了关于创建一个从Keras.utils.Sequence()这样继承的类:

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x …
Run Code Online (Sandbox Code Playgroud)

python multithreading thread-safety multiprocessing keras

8
推荐指数
1
解决办法
3593
查看次数