Min*_*ark 5 python ragged tensorflow
我需要通过沿参差不齐的维度进行索引来获取参差不齐的张量中的值。一些索引有效([:, :x],[:, -x:]或[:, x:y]),但不能直接索引([:, x]):
R = tf.RaggedTensor.from_tensor([[1, 2, 3], [4, 5, 6]])
print(R[:, :2]) # RaggedTensor([[1, 2], [4, 5]])
print(R[:, 1:2]) # RaggedTensor([[2], [5]])
print(R[:, 1]) # ValueError: Cannot index into an inner ragged dimension.
Run Code Online (Sandbox Code Playgroud)
该文档解释了为何失败:
RaggedTensors 支持多维索引和切片,但有一个限制:不允许对参差不齐的维度进行索引。这种情况是有问题的,因为指示的值可能存在于某些行中,但不存在于其他行中。在这种情况下,我们是否应该 (1) 引发 IndexError 并不明显;(2) 使用默认值;或 (3) 跳过该值并返回一个比我们开始时行数更少的张量。遵循 Python 的指导原则(“面对歧义,拒绝猜测”),我们目前不允许此操作。
这是有道理的,但我如何实际实施选项 1、2 和 3?我是否必须将参差不齐的数组转换为张量的 Python 数组,然后手动迭代它们?有没有更有效的解决方案?一个可以在 TensorFlow 图中 100% 工作的,而无需通过 Python 解释器?
如果您有一个 2D RaggedTensor,那么您可以通过以下方式获得行为 (3):
def get_column_slice_v3(rt, column):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
return slice.flat_values
Run Code Online (Sandbox Code Playgroud)
您可以通过添加 rt.nrows() == tf.size(slice.flat_values) 断言来获得行为 (1):
def get_column_slice_v1(rt, column):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
with tf.assert_equal(rt.nrows(), tf.size(slice.flat_values):
return tf.identity(slice.flat_values)
Run Code Online (Sandbox Code Playgroud)
为了获得行为(2),我认为最简单的方法可能是连接默认值向量,然后再次切片:
def get_colum_slice_v2(rt, column, default=None):
assert column >= 0 # Negative column index not supported
slice = rt[:, column:column+1]
if default is None:
defaults = tf.zeros([slice.nrows(), 1], slice.dtype)
ele:
defaults = tf.fill([slice.nrows(), 1], default)
slice_plus_default = tf.concat([rt, defaults], axis=1)
slice2 = slice_plus_defaults[:1]
return slice2.flat_values
Run Code Online (Sandbox Code Playgroud)
可以扩展它们以支持更高维的不规则张量,但逻辑会变得更加复杂。此外,应该可以扩展它们以支持负列索引。
| 归档时间: |
|
| 查看次数: |
1053 次 |
| 最近记录: |