Zac*_*ach 5 keras tensorflow tf.keras
我想我会分享一些我花了一段时间才弄明白的东西:用 TF 数据集对象轻松包装现有的 Keras 序列类。在遵循教程并从 TF 1.X 和 Keras 迁移到 TF 2.XI 之后,终于找到了如何用最少的代码来做到这一点。希望我不是唯一一个为此而苦苦挣扎的人,其他人会发现这很有帮助:)
几个假设:
序列类加载数据和标签
标签与源数据具有相同的形状(除了通道)(即这是我用来训练 U-Nets 的东西)
数据格式是通道最后
import tensorflow as tf
def DatasetFromSequenceClass(sequenceClass, stepsPerEpoch, nEpochs, batchSize, dims=[512,512,3], n_classes=2, data_type=tf.float32, label_type=tf.float32):
# eager execution wrapper
def DatasetFromSequenceClassEagerContext(func):
def DatasetFromSequenceClassEagerContextWrapper(batchIndexTensor):
# Use a tf.py_function to prevent auto-graph from compiling the method
tensors = tf.py_function(
func,
inp=[batchIndexTensor],
Tout=[data_type, label_type]
)
# set the shape of the tensors - assuming channels last
tensors[0].set_shape([batchSize, dims[0], dims[1], dims[2]]) # [samples, height, width, nChannels]
tensors[1].set_shape([batchSize, dims[0], dims[1], n_classes]) # [samples, height, width, nClasses for one hot]
return tensors
return DatasetFromSequenceClassEagerContextWrapper
# TF dataset wrapper that indexes our sequence class
@DatasetFromSequenceClassEagerContext
def LoadBatchFromSequenceClass(batchIndexTensor):
# get our index as numpy value - we can use .numpy() because we have wrapped our function
batchIndex = batchIndexTensor.numpy()
# zero-based index for what batch of data to load; i.e. goes to 0 at stepsPerEpoch and starts cound over
zeroBatch = batchIndex % stepsPerEpoch
# load data
data, labels = sequenceClass[zeroBatch]
# convert to tensors and return
return tf.convert_to_tensor(data), tf.convert_to_tensor(labels)
# create our data set for how many total steps of training we have
dataset = tf.data.Dataset.range(stepsPerEpoch*nEpochs)
# return dataset using map to load our batches of data, use TF to specify number of parallel calls
return dataset.map(LoadBatchFromSequenceClass, num_parallel_calls=tf.data.experimental.AUTOTUNE)
Run Code Online (Sandbox Code Playgroud)
使用该功能,您可以将训练更新为如下所示:
# load our data as tensorflow datasets
training = DatasetFromSequenceClass(trainingSequence, training_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)
validation = DatasetFromSequenceClass(validationSequence, validation_steps, nEpochs, batchSize, dims=shp, n_classes=nClasses)
# train
model_object.fit(training,
steps_per_epoch=training_steps,
validation_data=validation,
validation_steps=validation_steps,
epochs=nEpochs,
callbacks=callbacks,
verbose=1)
Run Code Online (Sandbox Code Playgroud)
从这里开始,Dataset API 有很多其他选项(如预取),但这应该是一个很好的起点。
| 归档时间: |
|
| 查看次数: |
646 次 |
| 最近记录: |