keras.utils.Sequence 如何工作?

Gau*_*mse 2 oop python-3.x keras tensorflow

我正在尝试为 U-net 创建一个用于图像分割的数据管道。我遇到了Keras.utils.Sequence一个类,通过它我可以创建一个数据管道,但我无法理解它是如何工作的。

代码链接Keras 代码源代码

  def __iter__(self):
    """Create a generator that iterate over the Sequence."""
    for item in (self[i] for i in range(len(self))):
      yield item
Run Code Online (Sandbox Code Playgroud)

如果有人能告诉我这是如何工作的,我将不胜感激?

elb*_*lbe 11

你不需要发电机。序列类就是用来管理它的。您需要定义一个继承自的类tensorflow.keras.utils.Sequence并定义方法: __init__, __getitem__, __len__。此外,您还可以定义方法on_epoch_end,该方法在每个 epoch 结束时调用,通常用于对样本索引进行洗牌。您提供的链接中有一个示例Tensorflow Sequence。下面是序列的另一个示例。请注意,您可以将数据传递给__init__构造函数,但您也可以从方法中的文件中读取数据__getitem__,假设您知道在哪里读取数据,例如,通过将一个或多个目录的名称传递给构造函数。如果有大量数据,这是必要的。

from tensorflow import keras
import numpy as np

class SequenceExample(keras.utils.Sequence):

    def __init__(self, x_in, y_in, batch_size, shuffle=True):
        # Initialization
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.x = x_in
        self.y = y_in
        self.datalen = len(y_in)
        self.indexes = np.arange(self.datalen)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        # get batch indexes from shuffled indexes
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        x_batch = self.x[batch_indexes]
        y_batch = self.y[batch_indexes]
        return x_batch, y_batch
    
    def __len__(self):
        # Denotes the number of batches per epoch
        return self.datalen // self.batch_size

    def on_epoch_end(self):
        # Updates indexes after each epoch
        self.indexes = np.arange(self.datalen)
        if self.shuffle:
            np.random.shuffle(self.indexes)
Run Code Online (Sandbox Code Playgroud)