Him*_*oon 9 python tensorflow tensorflow-datasets tensorflow2.0
在 TensorFlow 1.X 中,您可以使用占位符动态更改批次大小。例如
dataset.batch(batch_size=tf.placeholder())
查看完整示例
你如何在 TensorFlow 2.0 中做到这一点?
我已经尝试了以下但它不起作用。
import numpy as np
import tensorflow as tf
def new_gen_function():
for i in range(100):
yield np.ones(2).astype(np.float32)
batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
batch_size=batch_size)
for data in train_ds:
print(data.shape[0])
batch_size.assign(10)
print(batch_size)
Run Code Online (Sandbox Code Playgroud)
输出
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...
Run Code Online (Sandbox Code Playgroud)
我正在使用 Gradient 磁带使用自定义训练循环训练模型。我怎样才能做到这一点?
据我所知,您应该实例化一个新的数据集迭代器以使更改生效。这将需要稍微调整以跳过已经看到的样本。
这是我最简单的解决方案:
import numpy as np
import tensorflow as tf
def get_dataset(batch_size, num_samples_seen):
return tf.data.Dataset.range(
100
).skip(
num_samples_seen
).batch(
batch_size=batch_size
)
def main():
batch_size = 1
num_samples_seen = 0
train_ds = get_dataset(batch_size, num_samples_seen)
ds_iterator = iter(train_ds)
while True:
try:
data = next(ds_iterator)
except StopIteration:
print("End of iteration")
break
print(data)
batch_size *= 2
num_samples_seen += data.shape[0]
ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
print("New batch size:", batch_size)
if __name__ == "__main__":
main()
Run Code Online (Sandbox Code Playgroud)
正如您在此处看到的,您必须实例化一个新数据集(通过调用get_dataset)并更新迭代器。
我不知道这种解决方案对性能的影响。也许还有另一种解决方案需要“仅”实例化一个batch步骤而不是整个数据集。
| 归档时间: |
|
| 查看次数: |
5187 次 |
| 最近记录: |