tf.data.Dataset - 删除缓存?

clo*_*udy 7 python tensorflow

是否可以删除调用后建立的内存缓存tf.data.Dataset.cache()

这就是我想做的。数据集的扩充是非常昂贵的,所以当前的代码或多或少是:

data = tf.data.Dataset(...) \
       .map(<expensive_augmentation>) \
       .cache() \
       # .shuffle().batch() etc. 
Run Code Online (Sandbox Code Playgroud)

然而,这意味着每次迭代data都会看到数据样本的相同增强版本。我想做的是使用缓存几个时期,然后重新开始,或者等效地执行类似的操作Dataset.map(<augmentation>).fleeting_cache().repeat(8)。这有可能实现吗?

AAu*_*ert 1

缓存生命周期与数据集绑定在一起,因此您可以通过重新创建数据集来实现这一点:

def create_dataset():
  dataset = tf.data.Dataset(...)
  dataset = dataset.map(<expensive_augmentation>)
  dataset = dataset.shuffle(...)
  dataset = dataset.batch(...)
  return dataset

for epoch in range(num_epochs):
  # Drop the cache every 8 epochs.
  if epoch % 8 == 0: dataset = create_dataset()
  for batch in dataset:
    train(batch)
Run Code Online (Sandbox Code Playgroud)