使用Tensorflow构建适用于可变批量大小的图形

hgy*_*gyp 11 python tensorflow

我使用tf.placeholders()ops来输入变量批量大小的输入,它们是2D张量,当我调用run()时,使用feed机制为这些张量提供不同的值.我有

TypeError:'Tensor'对象不可迭代.

以下是我的代码:

with graph.as_default():
    train_index_input = tf.placeholder(tf.int32, shape=(None, window_size))
    train_embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_dimension], -1.0, 1.0))
    embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]
    ......
    ......
Run Code Online (Sandbox Code Playgroud)

由于我无法在不运行图形的情况下看到张量"train_index_input"的内容,因此"'Tensor'对象的错误不可迭代"会引发代码:

embedding_input = [tf.nn.embedding_lookup(train_embeddings, x) for x in train_index_input]
Run Code Online (Sandbox Code Playgroud)

我想要获得的是嵌入矩阵"embedding_input",其形状[batch_size,embedding_dimension] batch_size不固定.我是否必须在Tensorflow中定义一个新操作来嵌入2D张量的查找?或者其他任何方式吗?谢谢

dga*_*dga 8

您正在尝试for x in train_index_input通过Tensorflow占位符执行python级别的列表理解().那不行 - Python不知道tf对象里面有什么.

要完成批量嵌入查找,您可以做的只是展平您的批次:

train_indexes_flat = tf.reshape(train_index_input, [-1])
Run Code Online (Sandbox Code Playgroud)

通过嵌入查找运行它:

looked_up_embeddings = tf.nn.embedding_lookup(train_embeddings, train_indexes_flat)
Run Code Online (Sandbox Code Playgroud)

然后将其重塑为正确的组:

embedding_input = tf.reshape(looked_up_embeddings, [-1, window_size])
Run Code Online (Sandbox Code Playgroud)