tf.split的输出是什么?

Cha*_*ine 2 python arrays tensorflow tensor

假设我有这个:

TensorShape([Dimension(None),Dimension(32)])

我在这个张量_X上使用tf.split,其尺寸如上:

_X = tf.split(_X, 128, 0) 
Run Code Online (Sandbox Code Playgroud)

这种新张量的形状是什么?输出是一个列表,因此很难知道这个新张量的形状.

Har*_*lla 10

tf.split()返回张量对象列表.你可以知道每个张量对象的形状如下

import tensorflow as tf

X = tf.random_uniform([256, 32]);
Y = tf.split(X,128,0)
Y_shape = tf.shape(Y[1])

sess = tf.Session()
X_v,Y_v,Y_shape_v = sess.run([X,Y,Y_shape]) 
# numpy style
print X_v.shape
print len(Y_v)
print Y_v[100].shape
# TF style
print len(Y)
print Y_shape_v
Run Code Online (Sandbox Code Playgroud)

输出:

(256, 32)
128
(2, 32)
128
[ 2 32]
Run Code Online (Sandbox Code Playgroud)

我希望这有帮助 !


Ank*_*sal 5

tf.split(X, row = n, column = m)用于将变量的数据集分为nm数和列数。

例如,我们有data_set x大小(10,10),然后tf.split(x, 2, 0)将打破的data_set x2集大小的(5, 10)

但是如果采取的tf.split(x, 2, 2)话,我们将得到4套大小的数据(5, 5)