TensorFlow:将张量拆分为“batch_size”切片

Ale*_*lex 5 python tensorflow

我有一个名为tensorshape的 3 级张量,[batch_size, axis_1, axis_2]并希望将它batch_size沿第一个轴分成多个切片,如下所示:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
Run Code Online (Sandbox Code Playgroud)

不幸的是,这不起作用,因为batch_size在构建图形期间尚不知道的值。

我该如何解决这个问题?

我收到此错误:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
Run Code Online (Sandbox Code Playgroud)

奇怪的是,尝试batch_size在其他 TensorFlow 函数中使用似乎有效:

tensor = tf.reshape(tensor, [batch_size, -1])
Run Code Online (Sandbox Code Playgroud)

尽管batch_size在图形构建过程中的值未知,但工作正常。

问题是特别多tf.split()吗?

Ale*_*lex 1

解决方法是:

batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
                        elems=tf.range(batch_size),
                        dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)

不过,我仍然对更好的解决方案感兴趣。