如何在 Tensorflow 2.0 数据集中动态更改批量大小?

Him*_*oon 9 python tensorflow tensorflow-datasets tensorflow2.0

在 TensorFlow 1.X 中,您可以使用占位符动态更改批次大小。例如

dataset.batch(batch_size=tf.placeholder())
查看完整示例

你如何在 TensorFlow 2.0 中做到这一点?

我已经尝试了以下但它不起作用。

import numpy as np
import tensorflow as tf


def new_gen_function():
    for i in range(100):
        yield np.ones(2).astype(np.float32)


batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
    batch_size=batch_size)

for data in train_ds:
    print(data.shape[0])
    batch_size.assign(10)
    print(batch_size)
Run Code Online (Sandbox Code Playgroud)

输出

5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...
Run Code Online (Sandbox Code Playgroud)

我正在使用 Gradient 磁带使用自定义训练循环训练模型。我怎样才能做到这一点?

Ale*_*NON 1

据我所知,您应该实例化一个新的数据集迭代器以使更改生效。这将需要稍微调整以跳过已经看到的样本。

这是我最简单的解决方案:

import numpy as np
import tensorflow as tf

def get_dataset(batch_size, num_samples_seen):
    return tf.data.Dataset.range(
        100
    ).skip(
        num_samples_seen
    ).batch(
        batch_size=batch_size
    )

def main():
    batch_size = 1
    num_samples_seen = 0

    train_ds = get_dataset(batch_size, num_samples_seen)

    ds_iterator = iter(train_ds)
    while True:
        try:
            data = next(ds_iterator)
        except StopIteration:
            print("End of iteration")
            break

        print(data)
        batch_size *= 2
        num_samples_seen += data.shape[0]
        ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
        print("New batch size:", batch_size)

if __name__ == "__main__":
    main()
Run Code Online (Sandbox Code Playgroud)

正如您在此处看到的,您必须实例化一个新数据集(通过调用get_dataset)并更新迭代器。

我不知道这种解决方案对性能的影响。也许还有另一种解决方案需要“仅”实例化一个batch步骤而不是整个数据集。