使用数组进行Tensorflow哈希表查找

Mih*_* L. 6 python numpy tensorflow

我正在尝试使用可HashMap与Tensorflow一起使用的功能类型。当键和值是int类型时,我可以使用它。但是当它们是数组时,它会给出错误- ValueError: Shapes (2,) and () are not compatible在线default_value)

import numpy as np
import tensorflow as tf


input_tensor = tf.constant([1, 1], dtype=tf.int64)
keys = tf.constant(np.array([[1, 1],[2, 2],[3, 3]]),  dtype=tf.int64)
values = tf.constant(np.array([[4, 1],[5, 1],[6, 1]]),  dtype=tf.int64)
default_value = tf.constant(np.array([1, 1]),  dtype=tf.int64)

table = tf.contrib.lookup.HashTable(
        tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
        default_value)

out = table.lookup(input_tensor)
with tf.Session() as sess:
    table.init.run()
    print(out.eval())
Run Code Online (Sandbox Code Playgroud)

Pet*_*dan 6

不幸的是,tf.contrib.lookup.HashTable仅适用于一维张量。这是一个使用tf.SparseTensors 的实现,当然,仅当您的键是整数(int32或int64)张量时,该实现才有效。

对于值,我将两列存储在两个单独的张量中,但是如果您有很多列,则可能只想将它们存储在一个大张量中,并将索引存储为一个值tf.SparseTensor

此代码(经过测试):

import tensorflow as tf

lookup = tf.placeholder( shape = ( 2, ), dtype = tf.int64 )
default_value = tf.constant( [ 1, 1 ], dtype = tf.int64 )
input_tensor = tf.constant( [ 1, 1 ], dtype=tf.int64)
keys = tf.constant( [ [ 1, 2 ], [ 3, 4 ], [ 5, 6 ] ],  dtype=tf.int64 )
values = tf.constant( [ [ 4, 1 ], [ 5, 1 ], [ 6, 1 ] ],  dtype=tf.int64 )
val0 = values[ :, 0 ]
val1 = values[ :, 1 ]

st0 = tf.SparseTensor( keys, val0, dense_shape = ( 7, 7 ) )
st1 = tf.SparseTensor( keys, val1, dense_shape = ( 7, 7 ) )

x0 = tf.sparse_slice( st0, lookup, [ 1, 1 ] )
y0 = tf.reshape( tf.sparse_tensor_to_dense( x0, default_value = default_value[ 0 ] ), () )
x1 = tf.sparse_slice( st1, lookup, [ 1, 1 ] )
y1 = tf.reshape( tf.sparse_tensor_to_dense( x1, default_value = default_value[ 1 ] ), () )

y = tf.stack( [ y0, y1 ], axis = 0 )

with tf.Session() as sess:
    print( sess.run( y, feed_dict = { lookup : [ 1, 2 ] } ) )
    print( sess.run( y, feed_dict = { lookup : [ 1, 1 ] } ) )
Run Code Online (Sandbox Code Playgroud)

将输出:

[4 1]
[1 1]

根据需要(查找该值[4,1]为键[1,2]和默认值[1,1][1,1] ,这点到一个不存在的条目。)