Computing edit distance (feed_dict error)

nfm*_*ure 5 sparse-matrix levenshtein-distance tensorflow

I have written some code in TensorFlow to compute the edit distance between one string and a set of strings. I can't figure out the error.

import tensorflow as tf
sess = tf.Session()

# Create input data
test_string = ['foo']
ref_strings = ['food', 'bar']

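def create_sparse_vec(word_list):
    num_words = len(word_list)
    # The last dimension must cover the longest word, or character indices fall out of bounds
    max_len = max(len(w) for w in word_list)
    # One index [word, 0, char_position] per character
    indices = [[xi, 0, yi] for xi, x in enumerate(word_list) for yi, _ in enumerate(x)]
    chars = list(''.join(word_list))
    return tf.SparseTensor(indices, chars, [num_words, 1, max_len])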


test_string_sparse = create_sparse_vec(test_string*len(ref_strings))
ref_string_sparse = create_sparse_vec(ref_strings)

sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True))

This code works, and running it produces the output:

array([[ 0.25],
       [ 1.  ]], dtype=float32)
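These two numbers can be checked by hand. With normalize=True, tf.edit_distance divides the Levenshtein distance by the length of the truth (reference) sequence: 'foo' → 'food' needs one insertion (1/4 = 0.25), and 'foo' → 'bar' needs three substitutions (3/3 = 1.0). A minimal pure-Python sketch of that calculation follows; the function names are hypothetical and not part of TensorFlow.

```python
def levenshtein(a, b):
    """Dynamic-programming edit distance where insert, delete and substitute each cost 1."""
    prev = list(range(len(b) + 1))  # distances from the empty prefix of a to each prefix of b
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,               # delete ca
                            curr[j - 1] + 1,           # insert cb
                            prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = curr
    return prev[-1]


def normalized_edit_distance(hypothesis, truth):
    # Mirrors normalize=True: divide by the length of the truth sequence
    return levenshtein(hypothesis, truth) / len(truth)


print([normalized_edit_distance('foo', ref) for ref in ['food', 'bar']])
# [0.25, 1.0]
```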

However, when I try to do the same thing by feeding the sparse tensors through sparse placeholders, I get an error.


Here is the traceback:

Traceback (most recent call last):
  File "<ipython-input-29-4e06de0b7af3>", line 1, in <module>
    sess.run(edit_distances, feed_dict=feed_dict)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run
    for subfeed, subfeed_val in _feed_fn(feed, feed_val):
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn
    return feed_fn(feed, feed_val)
  File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda>
    [feed.indices, feed.values, feed.shape], feed_val)),
TypeError: zip argument #2 must support iteration

Any idea what is going on?

mrr*_*rry 4

TL;DR: Use tf.SparseTensorValue instead of tf.SparseTensor as the return type of create_sparse_vec().

The problem here comes from the return type of create_sparse_vec(), which is tf.SparseTensor and is not understood as a feed value in the call to sess.run().
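The last frame of the traceback shows the mechanism: the session expands a sparse feed with zip([feed.indices, feed.values, feed.shape], feed_val), so feed_val must be iterable as three components. A tf.SparseTensor object is not iterable, whereas tf.SparseTensorValue is a named tuple. The sketch below reproduces this in plain Python; FakeSparseTensor, FakeSparseTensorValue and the placeholder names are hypothetical stand-ins, not TensorFlow classes.

```python
from collections import namedtuple


# Stand-in for a symbolic sparse tensor: attributes, but no __iter__
class FakeSparseTensor:
    def __init__(self, indices, values, shape):
        self.indices, self.values, self.shape = indices, values, shape


# Stand-in for tf.SparseTensorValue: a named tuple, so it unpacks into three items
FakeSparseTensorValue = namedtuple('FakeSparseTensorValue', ['indices', 'values', 'shape'])

placeholder_parts = ['indices_ph', 'values_ph', 'shape_ph']  # hypothetical names


def expand_feed(feed_val):
    # Same zip pattern as the session.py frame in the traceback
    return list(zip(placeholder_parts, feed_val))


value = FakeSparseTensorValue([[0, 0, 0]], ['f'], [1, 1, 1])
print(expand_feed(value))
# [('indices_ph', [[0, 0, 0]]), ('values_ph', ['f']), ('shape_ph', [1, 1, 1])]

try:
    expand_feed(FakeSparseTensor([[0, 0, 0]], ['f'], [1, 1, 1]))
except TypeError as err:
    # zip rejects the non-iterable second argument, as in the traceback
    print('TypeError:', err)
```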

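When you feed a (dense) tf.Tensor, the expected value type is a NumPy array (or some object that can be converted to an array). When you feed a tf.SparseTensor, the expected value type is a tf.SparseTensorValue, which is similar to a tf.SparseTensor but whose indices, values, and shape properties are NumPy arrays (or some objects that can be converted to arrays, like the lists in your example).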
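To make those three components concrete, here is a pure-Python sketch of what create_sparse_vec() builds for a word list. It uses a hypothetical SparseTriple named tuple in place of tf.SparseTensorValue so that it runs without TensorFlow, and it sizes the last dimension to the longest word: a shape of [num_words, 1, 1] would leave every character position after the first out of bounds.

```python
from collections import namedtuple

# Hypothetical stand-in with the same field layout as an old-style SparseTensorValue
SparseTriple = namedtuple('SparseTriple', ['indices', 'values', 'shape'])


def sparse_components(word_list):
    num_words = len(word_list)
    max_len = max(len(w) for w in word_list)
    # Index [word, 0, position] for every character, in row-major (sorted) order
    indices = [[xi, 0, yi] for xi, word in enumerate(word_list) for yi in range(len(word))]
    values = list(''.join(word_list))
    return SparseTriple(indices, values, [num_words, 1, max_len])


triple = sparse_components(['food', 'bar'])
print(triple.shape)        # [2, 1, 4]
print(triple.indices[:5])  # [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [1, 0, 0]]
print(triple.values)       # ['f', 'o', 'o', 'd', 'b', 'a', 'r']
```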

The following code should work:

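def create_sparse_vec(word_list):
    num_words = len(word_list)
    max_len = max(len(w) for w in word_list)  # last dimension must cover the longest word
    indices = [[xi, 0, yi] for xi, x in enumerate(word_list) for yi, _ in enumerate(x)]
    chars = list(''.join(word_list))
    # A SparseTensorValue holds plain lists/arrays, so it can be fed to a sparse placeholder
    return tf.SparseTensorValue(indices, chars, [num_words, 1, max_len])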
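For completeness, here is a sketch of the full placeholder-based pipeline with this fix applied. It assumes a TensorFlow 2.x install and therefore goes through the tf.compat.v1 API; on TensorFlow 1.x the same calls are available without the compat.v1 prefix. The names test_input and ref_input are illustrative; edit_distances and feed_dict match the names in the traceback.

```python
import tensorflow as tf

# Placeholders and Session require graph mode when running under TensorFlow 2.x
tf.compat.v1.disable_eager_execution()


def create_sparse_vec(word_list):
    num_words = len(word_list)
    max_len = max(len(w) for w in word_list)
    indices = [[xi, 0, yi] for xi, x in enumerate(word_list) for yi in range(len(x))]
    chars = list(''.join(word_list))
    # A plain named-tuple value, which is what a sparse placeholder accepts as a feed
    return tf.compat.v1.SparseTensorValue(indices, chars, [num_words, 1, max_len])


test_string = ['foo']
ref_strings = ['food', 'bar']

# Symbolic sparse inputs; their contents are supplied at run time through feed_dict
test_input = tf.compat.v1.sparse_placeholder(dtype=tf.string)
ref_input = tf.compat.v1.sparse_placeholder(dtype=tf.string)
edit_distances = tf.edit_distance(test_input, ref_input, normalize=True)

feed_dict = {test_input: create_sparse_vec(test_string * len(ref_strings)),
             ref_input: create_sparse_vec(ref_strings)}

with tf.compat.v1.Session() as sess:
    result = sess.run(edit_distances, feed_dict=feed_dict)

# Shape (2, 1) with values 0.25 and 1.0, matching the non-placeholder version
print(result)
```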