TensorFlow 中的 tf.newaxis 操作

Cel*_*eli 16 python tensorflow

x_train = x_train[..., tf.newaxis].astype("float32")

x_test = x_test[..., tf.newaxis].astype("float32")

有人可以解释一下如何tf.newaxis工作吗?

我在文档中发现了一个简短的提及

https://www.tensorflow.org/api_docs/python/tf/strided_slice

但我无法正确理解。

Tim*_*lin 17

检查这个例子:

a = tf.constant([100])
print(a.shape) ## (1)
expanded_1 = tf.expand_dims(a,axis=1)
print(expanded_1.shape) ## (1,1)
expanded_2 = a[:, tf.newaxis]
print(expanded_2.shape) ## (1,1)
Run Code Online (Sandbox Code Playgroud)

expand_dims()与添加新轴类似。

如果要在张量的开头添加新轴,请使用

expanded_2 = a[tf.newaxis, :]
Run Code Online (Sandbox Code Playgroud)

否则(最后)

expanded_2 = a[:,tf.newaxis]
Run Code Online (Sandbox Code Playgroud)


Fur*_*sen 8

您还可以使用 向张量添加维度,同时保留相同的信息tf.newaxis

# Create a rank 2 tensor (2 dimensions)
rank_2_tensor = tf.constant([[10, 7],
                             [3, 4]])

print("dimension: ", rank_2_tensor.ndim)
print("shape    : ", rank_2_tensor.shape)
Run Code Online (Sandbox Code Playgroud)

输出:

维度:2
形状:TensorShape([2, 2])

# Add an extra dimension (to the end)
rank_3_tensor = rank_2_tensor[..., tf.newaxis] 
# in Python "..." means "all dimensions prior to"

print("dimension: ", rank_3_tensor .ndim)
print("shape    : ", rank_3_tensor .shape)
Run Code Online (Sandbox Code Playgroud)

输出:

维度:3
形状:TensorShape([2, 2, 1])

您可以使用 tf.expand_dims() 实现相同的效果。

rank_new_3_tensor = tf.expand_dims(rank_2_tensor, axis=-1) # "-1" means last axis
print("dimension: ", rank_new_3_tensor .ndim)
print("shape    : ", rank_new_3_tensor .shape)
Run Code Online (Sandbox Code Playgroud)

输出:

维度:3
形状:TensorShape([2, 2, 1])