使用占位符值重构张量

jst*_*er7 8 tensorflow

我想使用[int,-1]表示法重新形成张量(例如,展平图像).但我不提前知道第一个维度.一个用例是大批量训练,然后在较小批次上进行评估.

为什么会出现以下错误:got list containing Tensors of type '_Message'

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, shape=[None, 28, 28])
batch_size = tf.placeholder(tf.int32)

def reshape(_batch_size):
    return tf.reshape(x, [_batch_size, -1])

reshaped = reshape(batch_size)


with tf.Session() as sess:
    sess.run([reshaped], feed_dict={x: np.random.rand(100, 28, 28), batch_size: 100})

    # Evaluate
    sess.run([reshaped], feed_dict={x: np.random.rand(8, 28, 28), batch_size: 8})
Run Code Online (Sandbox Code Playgroud)

注意:当我在函数外部重塑它似乎工作,但我有很多次使用的非常大的模型,所以我需要将它们保存在一个函数中并使用参数传递dim.

mrr*_*rry 10

要使其工作,请替换以下函数:

def reshape(_batch_size):
    return tf.reshape(x, [_batch_size, -1])
Run Code Online (Sandbox Code Playgroud)

...具有以下功能:

def reshape(_batch_size):
    return tf.reshape(x, tf.pack([_batch_size, -1]))
Run Code Online (Sandbox Code Playgroud)

出错的原因是tf.reshape()期望一个可转换为a的值tf.Tensor作为其第二个参数.TensorFlow会自动将Python数字列表转换为a tf.Tensor但不会自动转换数字和张量的混合列表(例如a tf.placeholder()) - 而不是提高您看到的有些不直观的错误消息.

tf.pack()运算需要花费列表转换为一个张量的对象,并且每个元件独立转换,所以它可以处理一个占位符和一个整数的组合.