关于tf.stack()轴的质询

nin*_*hut 1 tensorflow

tf.stack()tensorflow stack上阅读了文档。页面上有一个示例:

>>> x = tf.constant([1, 4])
>>> y = tf.constant([2, 5])
>>> z = tf.constant([3, 6])
>>> sess=tf.Session()
>>> sess.run(tf.stack([x, y, z]))
array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
>>> sess.run(tf.stack([x, y, z], axis=1))
array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
Run Code Online (Sandbox Code Playgroud)

我不明白的是第二个例子axis=1

从结果看来,它首先将三个输入行转换为列

然后将它们沿着拖走axis=1,但是

我认为结果应该是

array([[1,4, 2, 5, 3, 6 ]] dtype=int32 )
Run Code Online (Sandbox Code Playgroud)

有人可以帮忙解释一下吗?

谢谢!

jde*_*esa 5

tf.stack总是添加一个新维度,并始终沿着该新维度连接给定张量。在您的情况下,您有三个带有shape的张量[2]。设置axis=0与添加新的第一维相同,因此每个张量现在将具有形状[1, 2],并在该维上串联,因此最终形状将为[3, 2]。也就是说,每个张量将是最终张量的“行”。随着axis=1每个张量的形状会扩展到上[2, 1],结果将形成形状[2, 3]。因此,每个给定的张量将是结果张量的“列”。

换句话说,tf.stack在功能上与此等效:

def tf.stack(tensors, axis=0):
    return tf.concatenate([tf.expand_dims(t, axis=axis) for t in tensors], axis=axis)
Run Code Online (Sandbox Code Playgroud)

但是,您期望的结果将通过以下方式获得:

tf.concatenate([tf.expand_dims(t, axis=0) for t in tensors], axis=1)
Run Code Online (Sandbox Code Playgroud)

请注意,在这种情况下,添加的维和串联的维是不同的。