Gau*_*mse 2 oop python-3.x keras tensorflow
我正在尝试为 U-net 创建一个用于图像分割的数据管道。我遇到了Keras.utils.Sequence一个类,通过它我可以创建一个数据管道,但我无法理解它是如何工作的。
def __iter__(self):
"""Create a generator that iterate over the Sequence."""
for item in (self[i] for i in range(len(self))):
yield item
Run Code Online (Sandbox Code Playgroud)
如果有人能告诉我这是如何工作的,我将不胜感激?
elb*_*lbe 11
你不需要发电机。序列类就是用来管理它的。您需要定义一个继承自的类tensorflow.keras.utils.Sequence并定义方法:
__init__, __getitem__, __len__。此外,您还可以定义方法on_epoch_end,该方法在每个 epoch 结束时调用,通常用于对样本索引进行洗牌。您提供的链接中有一个示例Tensorflow Sequence。下面是序列的另一个示例。请注意,您可以将数据传递给__init__构造函数,但您也可以从方法中的文件中读取数据__getitem__,假设您知道在哪里读取数据,例如,通过将一个或多个目录的名称传递给构造函数。如果有大量数据,这是必要的。
from tensorflow import keras
import numpy as np
class SequenceExample(keras.utils.Sequence):
def __init__(self, x_in, y_in, batch_size, shuffle=True):
# Initialization
self.batch_size = batch_size
self.shuffle = shuffle
self.x = x_in
self.y = y_in
self.datalen = len(y_in)
self.indexes = np.arange(self.datalen)
if self.shuffle:
np.random.shuffle(self.indexes)
def __getitem__(self, index):
# get batch indexes from shuffled indexes
batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
x_batch = self.x[batch_indexes]
y_batch = self.y[batch_indexes]
return x_batch, y_batch
def __len__(self):
# Denotes the number of batches per epoch
return self.datalen // self.batch_size
def on_epoch_end(self):
# Updates indexes after each epoch
self.indexes = np.arange(self.datalen)
if self.shuffle:
np.random.shuffle(self.indexes)
Run Code Online (Sandbox Code Playgroud)