我tf.stack()在tensorflow stack上阅读了文档。页面上有一个示例:
>>> x = tf.constant([1, 4])
>>> y = tf.constant([2, 5])
>>> z = tf.constant([3, 6])
>>> sess=tf.Session()
>>> sess.run(tf.stack([x, y, z]))
array([[1, 4],
[2, 5],
[3, 6]], dtype=int32)
>>> sess.run(tf.stack([x, y, z], axis=1))
array([[1, 2, 3],
[4, 5, 6]], dtype=int32)
Run Code Online (Sandbox Code Playgroud)
我不明白的是第二个例子axis=1。
从结果看来,它首先将三个输入行转换为列
然后将它们沿着拖走axis=1,但是
我认为结果应该是
array([[1,4, 2, 5, 3, 6 ]] dtype=int32 )
Run Code Online (Sandbox Code Playgroud)
有人可以帮忙解释一下吗?
谢谢!
tf.stack总是添加一个新维度,并始终沿着该新维度连接给定张量。在您的情况下,您有三个带有shape的张量[2]。设置axis=0与添加新的第一维相同,因此每个张量现在将具有形状[1, 2],并在该维上串联,因此最终形状将为[3, 2]。也就是说,每个张量将是最终张量的“行”。随着axis=1每个张量的形状会扩展到上[2, 1],结果将形成形状[2, 3]。因此,每个给定的张量将是结果张量的“列”。
换句话说,tf.stack在功能上与此等效:
def tf.stack(tensors, axis=0):
return tf.concatenate([tf.expand_dims(t, axis=axis) for t in tensors], axis=axis)
Run Code Online (Sandbox Code Playgroud)
但是,您期望的结果将通过以下方式获得:
tf.concatenate([tf.expand_dims(t, axis=0) for t in tensors], axis=1)
Run Code Online (Sandbox Code Playgroud)
请注意,在这种情况下,添加的维和串联的维是不同的。
| 归档时间: |
|
| 查看次数: |
2084 次 |
| 最近记录: |