相关疑难解决方法(0)

Keras / Tensorflow中的类生成器(继承序列)线程安全吗?

为了使模型的训练更快,在CPU上填充/生成批次并在GPU上并行运行模型的训练似乎是一个好习惯。为此,可以使用Python编写一个生成器类来继承Sequence该类。

这是文档的链接:https : //www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

该文档指出的重要内容是:

Sequence是进行多处理的更安全方法。这种结构保证了网络在每个时期的每个样本上只会训练一次,而生成器则不会。

它给出了一个简单的代码示例,如下所示:

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx …
Run Code Online (Sandbox Code Playgroud)

python generator keras tensorflow

5
推荐指数
1
解决办法
1757
查看次数

标签 统计

generator ×1

keras ×1

python ×1

tensorflow ×1