关于 fit_generator() / fit() 和线程安全

Mar*_*kus 8 python multithreading thread-safety multiprocessing keras

语境

为了fit_generator()在 Keras 中使用,我使用了一个像这样的伪代码-one的生成器函数:

def generator(data: np.array) -> (np.array, np.array):
    """Simple generator yielding some samples and targets"""

    while True:
        for batch in range(number_of_batches):
            yield data[batch * length_sequence], data[(batch + 1) * length_sequence]
Run Code Online (Sandbox Code Playgroud)

在 Keras 的fit_generator()函数中,我想使用workers=4并且use_multiprocessing=True- 因此,我需要一个线程安全生成器。

在像herehere或Keras docs这样的stackoverflow的答案中,我读到了关于创建一个从Keras.utils.Sequence()这样继承的类:

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return ...
Run Code Online (Sandbox Code Playgroud)

使用SequencesKeras 不会在使用多个工作和多处理时发出任何警告;生成器应该是线程安全的。

无论如何,由于我正在使用我的自定义函数,我偶然发现了github上提供的 Omer Zohars 代码,它允许generator()通过添加装饰器来使我的线程安全。代码如下:

import threading

class threadsafe_iter:
    """
    Takes an iterator/generator and makes it thread-safe by
    serializing call to the `next` method of given iterator/generator.
    """
    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe."""
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))

    return g
Run Code Online (Sandbox Code Playgroud)

现在我可以这样做:

@threadsafe_generator
def generator(data):
    ...
Run Code Online (Sandbox Code Playgroud)

问题是:使用这个版本的线程安全生成器 Keras 仍然会发出警告,即生成器在使用时必须是线程安全的workers > 1use_multiprocessing=True并且可以通过使用Sequences.


我现在的问题是:

  1. Keras 是否仅因为生成器没有继承而发出此警告Sequences,还是 Keras 还检查生成器是否是线程安全的?
  2. 是否使用我选择的线程安全方法作为使用Keras-docs 中generatorClass(Sequence)-version ?
  3. 是否有任何其他方法导致 Keras 可以处理与这两个示例不同的线程安全生成器?


编辑: 在较新的tensorflow/ keras-versions ( tf> 2)fit_generator()中已弃用。相反,建议fit()与生成器一起使用。但是,这个问题仍然适用于fit()使用生成器。

Mar*_*kus 12

在我对此进行研究期间,我发现了一些信息来回答我的问题。

注意:如更新的问题中更新tensorflow/ keras-versions ( tf> 2)fit_generator()已弃用。相反,建议fit()与生成器一起使用。但是,答案仍然适用于fit()使用生成器。


1. Keras 发出这个警告只是因为生成器没有继承 Sequences,还是 Keras 还检查生成器是否是线程安全的?

取自 Keras 的 gitRepo ( training_generators.py ),我在46-52以下几行中找到了:

use_sequence_api = is_sequence(generator)
if not use_sequence_api and use_multiprocessing and workers > 1:
    warnings.warn(
        UserWarning('Using a generator with `use_multiprocessing=True`'
                    ' and multiple workers may duplicate your data.'
                    ' Please consider using the `keras.utils.Sequence'
                    ' class.'))
Run Code Online (Sandbox Code Playgroud)

is_sequence()training_utils.py中获取的定义624-635是:

def is_sequence(seq):
    """Determine if an object follows the Sequence API.
    # Arguments
        seq: a possible Sequence object
    # Returns
        boolean, whether the object follows the Sequence API.
    """
    # TODO Dref360: Decide which pattern to follow. First needs a new TF Version.
    return (getattr(seq, 'use_sequence_api', False)
            or set(dir(Sequence())).issubset(set(dir(seq) + ['use_sequence_api'])))
Run Code Online (Sandbox Code Playgroud)

关于这段代码,Keras 只检查传递的生成器是否是 Keras 序列(或者更确切地说是使用 Keras 的序列 API),并且通常不检查生成器是否是线程安全的。


2.是否使用我选择的线程安全方法作为使用Keras-docs 中的 generatorClass(Sequence)-version ?

正如 Omer Zohar 在gitHub上所展示的,他的装饰器是线程安全的——我看不出有任何理由说明它对Keras来说不应该是线程安全的(尽管 Keras 会发出警告,如 1 所示)。thread.Lock()根据文档,的实现可以被认为是线程安全的:

返回一个新的原始锁对象的工厂函数。一旦线程获取了它,后续获取它的尝试就会阻塞,直到它被释放;任何线程都可以释放它。

生成器也是可酸洗的,可以像这样进行测试(有关更多信息,请参见此处的SO-Q&A ):

#Dump yielded data in order to check if picklable
with open("test.pickle", "wb") as outfile:
    for yielded_data in generator(data):
        pickle.dump(yielded_data, outfile, protocol=pickle.HIGHEST_PROTOCOL)
Run Code Online (Sandbox Code Playgroud)

恢复这一点,我什至建议thread.Lock()您在扩展 Keras 时实施,Sequence()例如:

import threading

class generatorClass(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.lock = threading.Lock()   #Set self.lock

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        with self.lock:                #Use self.lock
            batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
            batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

            return ...
Run Code Online (Sandbox Code Playgroud)

24/04/2020 编辑:

通过使用self.lock = threading.Lock()您可能会遇到以下错误:

类型错误:无法pickle _thread.lock 对象

如果发生这种情况尝试更换with self.lock:内部__getitem__with threading.Lock():和注释掉/删除self.lock = threading.Lock()里面的__init__

lock-object存储在类中时似乎存在一些问题(例如参见问答)。


3.是否有任何其他方法导致 Keras 可以处理的线程安全生成器与这两个示例不同?

在我的研究过程中,我没有遇到任何其他方法。当然,我不能100%肯定地说。