如何使用TensorFlow张量索引列表?

spr*_*sel 10 python indexing list tensorflow

假设需要通过查找表访问的具有不可连接对象的列表.因此列表索引将是张量对象,但这是不可能的.

 tf_look_up = tf.constant(np.array([3, 2, 1, 0, 4]))
 index = tf.constant(2)
 list = [0,1,2,3,4]

 target = list[tf_look_up[index]]
Run Code Online (Sandbox Code Playgroud)

这将显示以下错误消息.

 TypeError: list indices must be integers or slices, not Tensor
Run Code Online (Sandbox Code Playgroud)

是使用张量索引列表的方法/解决方法吗?

sol*_*ice 12

tf.gather 是为此目的而设计的.

简单地说tf.gather(list, tf_look_up[index]),你会得到你想要的.


jbi*_*ird 2

Tensorflow 实际上支持HashTable. 请参阅文档了解更多详细信息。

在这里,您可以执行以下操作:

table = tf.contrib.lookup.HashTable(
    tf.contrib.lookup.KeyValueTensorInitializer(tf_look_up, list), -1)
Run Code Online (Sandbox Code Playgroud)

然后通过运行获取所需的输入

target = table.lookup(index)
Run Code Online (Sandbox Code Playgroud)

请注意,-1如果未找到密钥,则这是默认值。您可能需要将key_dtype和添加value_dtype到构造函数中,具体取决于张量的配置。