小编Joh*_*ohn的帖子

确定 tf.data.Dataset Tensorflow 中的记录数

我想将数据集迭代器传递给函数,但该函数需要知道数据集的长度。在下面的例子,我可以传递len(datafiles)my_custom_fn()功能,但我想知道如果我能够从任一提取数据集的长度iteratorbatch_xbatch_y类,以便我没有将其添加为输入。

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
Run Code Online (Sandbox Code Playgroud)

谢谢!

编辑:此解决方案在我的情况下不起作用tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?

运行后

tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])
Run Code Online (Sandbox Code Playgroud)

返回

<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>
Run Code Online (Sandbox Code Playgroud)

我确实找到了适合我的解决方案。将 iterator_scope 添加到我的代码中,例如:

with tf.name_scope('iter'):
    dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)
    [batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# …
Run Code Online (Sandbox Code Playgroud)

python machine-learning deep-learning tensorflow

6
推荐指数
1
解决办法
6939
查看次数

训练后用占位符交换TensorFlow数据集输入管道

我正在使用新的tf.data.DatasetAPI,但似乎无法弄清楚如何执行推理。最终,我想将我的模型转换为TensorRT图并在TX2上运行它,并且我发现的所有示例都假设您tf.placeholder输入的是a 。这是我如何训练的伪代码。该[...]只是要用作占位符,因为我实际上没有运行代码。让我们不要争论该模型,因为它只是想举一个例子:

import tensorflow as tf

# Setup iterator
datain = tf.data.FixedLengthRecordDataset(datafiles, record_bytes1)
labels = tf.data.FixedLengthRecordDataset(labelfiles, record_bytes2)
dataset = tf.data.Dataset.zip((datain, labels))
dataset = dataset.prefetch(batch_size)
dataset = dataset.repeat(n_epoch)
iterator = dataset.make_initializable_iterator()

sess = tf.Session()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()

# Define model function (let's not debate model except as relevant to question)
def model_fn(xin):
    x0 = tf.transpose(tf.reshape(xin, [...], name='input'))
    w = tf.Variable(tf.truncated_normal([...], stddev=0.1))
    x1 = tf.nn.conv2d(x0, w, strides=[...], padding='VALID')
    b = tf.Variable(tf.constant(0.0, shape=[...])) …
Run Code Online (Sandbox Code Playgroud)

python tensorflow tensorrt

5
推荐指数
1
解决办法
1035
查看次数