Tensorflow元组具有不同的形状

Mar*_*edt 5 python factorization tensorflow

我有一个问题,返回一个两个变量的元组v,wt其中vhas shape=(20,20)wthas shape=(1,).wt是一个重量值的变量.我想在一个内部返回元组(v,wt)map_fn

我的代码看起来有点接近这个

tf.map_fn(fn, nonzeros(Matrix, dim, row))

nonzeros(Matrix, dim, row) returns a (index, value)
Run Code Online (Sandbox Code Playgroud)

fn会返回一个元组,但错误输出我得到的是:

ValueError: The two structures don't have the same number of elements. First 
structure: <dtype: 'int64'>, second structure: (<tf.Tensor 
'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>, <tf.Tensor 
'map_2/while/Sub:0' shape=() dtype=int64>).
Run Code Online (Sandbox Code Playgroud)

Dav*_*rks 1

您将在此处返回循环的结果tf.while。循环tf.while返回多个值的元组,在您的情况下,我们可以看到您的 while 循环返回一个感兴趣的值和一个计数器值作为元组。

(<tf.Tensor 'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>, <tf.Tensor 'map_2/while/Sub:0' shape=() dtype=int64>)
Run Code Online (Sandbox Code Playgroud)

您要传回的内容map_fn可能只是这两个值中的第一个。因此,在此处未显示的代码中,您应该具有以下内容:

value, counter = tf.while(...)
return value
Run Code Online (Sandbox Code Playgroud)

你拥有的是:

return tf.while(...)
Run Code Online (Sandbox Code Playgroud)

因此,您看到的错误是抱怨 an与您传入的<dtype: 'int64'>不匹配。当您修复 while 循环时,您将比较可能都是 (20,) 并且将匹配的(尽管您最终可能会存在 int/float 问题)。tuple<dtype: 'int64'><tf.Tensor 'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>