在 Tensorflow 中删除张量的维度

lam*_*ung 6 python slice tensorflow tensor

我有一个有形状的张量,(50, 100, 1, 512)我想重塑它或删除第三维,以便新的张量有形状(50, 100, 512)

我曾尝试tf.slicetf.squeeze

a = tf.slice(a, [50, 100, 1, 512], [50, 100, 1, 512])
b = tf.squeeze(a)
Run Code Online (Sandbox Code Playgroud)

当我尝试打印 的形状时,一切似乎都在工作ab但是当我开始训练我的模型时,出现了这个错误

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected size[0] in [0, 0], but got 50
     [[Node: Slice = Slice[Index=DT_INT32, T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](MaxPool_2, Slice/begin, Slice/size)]]
Run Code Online (Sandbox Code Playgroud)

我的slice. 我该如何解决。谢谢

mni*_*nis 8

有多种方法可以做到这一点。Tensorflow 已开始支持索引。尝试

a = a[:, :, 0, :]
Run Code Online (Sandbox Code Playgroud)

或者

a = a[:, :, -1, :]
Run Code Online (Sandbox Code Playgroud)

或者

a = tf.reshape(a, [50, 100, 512])
Run Code Online (Sandbox Code Playgroud)

或者

a = tf.squeeze(a)
Run Code Online (Sandbox Code Playgroud)


小智 7

一般tf.squeeze都会掉尺寸。

a = tf.constant([[[1,2,3],[3,4,5]]])
Run Code Online (Sandbox Code Playgroud)

上面的张量形状是[1,2,3]。执行挤压操作后,

b = tf.squeeze(a)
Run Code Online (Sandbox Code Playgroud)

现在,张量形状是 [2,3]