如何在TensorFlow中切分4级张量?

Sha*_*ban 1 slice tensorflow

我试图使用tf.slice()运算符切割四维张量,如下所示:

x_image = tf.reshape(x, [-1,28,28,1], name='Images_2D')
slice_im = tf.slice(x_image,[0,2,2],[1, 24, 24])
Run Code Online (Sandbox Code Playgroud)

但是,当我尝试运行此代码时,我得到以下异常:

raise ValueError("Shape %s must have rank %d" % (self, rank))

ValueError: Shape TensorShape([Dimension(None), Dimension(28), Dimension(28), Dimension(1)]) must have rank 3
Run Code Online (Sandbox Code Playgroud)

我该如何切割这个张量?

mrr*_*rry 7

tf.slice(input, begin, size)操作者要求beginsize矢量-限定了子张要被切断-具有相同的长度,如维度数目input.因此,要切片4-D张量,必须传递四个数字的向量(或列表)作为第二个和第三个参数tf.slice().

例如:

x_image = tf.reshape(x, [-1, 28, 28, 1], name='Images_2D')

slice_im = tf.slice(x_image, [0, 2, 2, 0], [1, 24, 24, 1])

# Or, using the indexing operator:
slice_im = x_image[0:1, 2:26, 2:26, :]
Run Code Online (Sandbox Code Playgroud)

索引运算符稍微强大一些,因为它还可以降低输出的等级,如果对于维度指定单个整数而不是范围:

slice_im = x_image[0:1, 2:26, 2:26, :]
print slice_im_2d.get_shape()  # ==> [1, 24, 24, 1]

slice_im_2d = x_image[0, 2:26, 2:26, 0]
print slice_im_2d.get_shape()  # ==> [24, 24]
Run Code Online (Sandbox Code Playgroud)