在tensorflow中展平批处理

Cac*_*tux 27 tensorflow

我有一个形状的张量流的输入[None, 9, 2](其中None是批量).

要对其执行进一步的操作(例如matmul),我需要将其转换为[None, 18]形状.怎么做?

wei*_*114 39

您可以使用tf.reshape()轻松完成,而无需了解批量大小.

x = tf.placeholder(tf.float32, shape=[None, 9,2])
shape = x.get_shape().as_list()        # a list: [None, 9, 2]
dim = numpy.prod(shape[1:])            # dim = prod(9,2) = 18
x2 = tf.reshape(x, [-1, dim])           # -1 means "all"
Run Code Online (Sandbox Code Playgroud)

-1无论batchsize是在运行什么在最后一行意味着整个列.你可以在tf.reshape()中看到它.


更新:shape = [无,3,无]

谢谢@kbrose.对于未定义多个维度的情况,我们可以选择使用tf.hape()tf.reduce_prod().

x = tf.placeholder(tf.float32, shape=[None, 3, None])
dim = tf.reduce_prod(tf.shape(x)[1:])
x2 = tf.reshape(x, [-1, dim])
Run Code Online (Sandbox Code Playgroud)

tf.shape()返回一个可以在运行时计算的形状Tensor.可以在doc中看到tf.get_shape()和tf.shape()之间的区别.

我还尝试了另一个tf.contrib.layers.flatten().对于第一种情况来说这是最简单的,但它无法处理第二种情况.


use*_*100 15

flat_inputs = tf.layers.flatten(inputs)
Run Code Online (Sandbox Code Playgroud)


Yar*_*tov 3

您可以使用动态重塑来通过tf.batch在运行时获取批量维度的值,将整组新维度计算到 中tf.reshape。这是一个在不知道列表长度的情况下将平面列表重塑为方阵的示例。

tf.reset_default_graph()
sess = tf.InteractiveSession("")
a = tf.placeholder(dtype=tf.int32)
# get [9]
ashape = tf.shape(a)
# slice the list from 0th to 1st position
ashape0 = tf.slice(ashape, [0], [1])
# reshape list to scalar, ie from [9] to 9
ashape0_flat = tf.reshape(ashape0, ())
# tf.sqrt doesn't support int, so cast to float
ashape0_flat_float = tf.to_float(ashape0_flat)
newshape0 = tf.sqrt(ashape0_flat_float)
# convert [3, 3] Python list into [3, 3] Tensor
newshape = tf.pack([newshape0, newshape0])
# tf.reshape doesn't accept float, so convert back to int
newshape_int = tf.to_int32(newshape)
a_reshaped = tf.reshape(a, newshape_int)
sess.run(a_reshaped, feed_dict={a: np.ones((9))})
Run Code Online (Sandbox Code Playgroud)

你应该看到

array([[1, 1, 1],
       [1, 1, 1],
       [1, 1, 1]], dtype=int32)
Run Code Online (Sandbox Code Playgroud)