我有一个名为tensorshape的 3 级张量,[batch_size, axis_1, axis_2]并希望将它batch_size沿第一个轴分成多个切片,如下所示:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
Run Code Online (Sandbox Code Playgroud)
不幸的是,这不起作用,因为batch_size在构建图形期间尚不知道的值。
我该如何解决这个问题?
我收到此错误:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
Run Code Online (Sandbox Code Playgroud)
奇怪的是,尝试batch_size在其他 TensorFlow 函数中使用似乎有效:
tensor = tf.reshape(tensor, [batch_size, -1])
Run Code Online (Sandbox Code Playgroud)
尽管batch_size在图形构建过程中的值未知,但工作正常。
问题是特别多tf.split()吗?
解决方法是:
batch_items = tf.map_fn(fn=lambda k: tensor[...,k],
elems=tf.range(batch_size),
dtype=tf.float32)
Run Code Online (Sandbox Code Playgroud)
不过,我仍然对更好的解决方案感兴趣。
| 归档时间: |
|
| 查看次数: |
1946 次 |
| 最近记录: |