Joh*_*ohn 6 python machine-learning deep-learning tensorflow
我想将数据集迭代器传递给函数,但该函数需要知道数据集的长度。在下面的例子,我可以传递len(datafiles)
到my_custom_fn()
功能,但我想知道如果我能够从任一提取数据集的长度iterator
,batch_x
或batch_y
类,以便我没有将其添加为输入。
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
Run Code Online (Sandbox Code Playgroud)
谢谢!
编辑:此解决方案在我的情况下不起作用:tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?
运行后
tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])
Run Code Online (Sandbox Code Playgroud)
返回
<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>
Run Code Online (Sandbox Code Playgroud)
我确实找到了适合我的解决方案。将 iterator_scope 添加到我的代码中,例如:
with tf.name_scope('iter'):
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
Run Code Online (Sandbox Code Playgroud)
然后从内部my_custom_fn
调用:
def my_custom_fn(batch_x, batch_y):
filenames = batch_x.graph.get_operation_by_name(
'iter/InputDataSet/filenames').outputs[0]
n_epoch = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/count').outputs)[0]
batch_size = sess.run(sess.graph.get_operation_by_name(
'iter/Iterator/batch_size').outputs)[0]
# lots of other stuff
Run Code Online (Sandbox Code Playgroud)
不确定这是否是最好的方法,但它似乎有效。很高兴就此提出任何建议,因为它看起来有点老套。
iterator
在迭代之前,an 的长度是未知的。您可以显式传递len(datafiles)
到该函数中,但如果您坚持数据的持久性,则可以简单地将该函数设为实例方法,并将数据集的长度存储在该my_custom_fn
方法的对象中。
不幸的是,由于iterator
它不存储任何内容,因此它会动态生成数据。然而,正如在 TensorFlow 的源代码中发现的那样,有一个“私有”变量_batch_size
用于存储批量大小。您可以在此处查看源代码:TensorFlow 源。
归档时间: |
|
查看次数: |
6939 次 |
最近记录: |