我发现索引仍然是tensorflow中的一个开放问题(#206),所以我想知道我现在可以使用什么作为解决方法.我想基于每个训练示例更改的变量来索引/切片矩阵的行/列.
到目前为止我尝试过的:
以下(工作)代码片基于固定数字.
import tensorflow as tf
import numpy as np
x = tf.placeholder("float")
y = tf.slice(x,[0],[1])
#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5]})
print(result)
Run Code Online (Sandbox Code Playgroud)
但是,似乎我不能简单地用tf.placeholder替换其中一个固定数字.下面的代码给出了错误"TypeError:预期单个Tensor时的张量列表".
import tensorflow as tf
import numpy as np
x = tf.placeholder("float")
i = tf.placeholder("int32")
y = tf.slice(x,[i],[1])
#initialize
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
#run
result = sess.run(y, feed_dict={x:[1,2,3,4,5],i:0})
print(result)
Run Code Online (Sandbox Code Playgroud)
这听起来像[i]周围的括号太多,但删除它们也无济于事.如何使用占位符/变量作为索引?
我也尝试使用普通的python变量作为索引.这不会导致错误,但网络在训练时不会学到任何东西.我想因为更改的变量没有正确注册,图表格式错误,更新不起作用?
我找到的一个解决方法是使用单热矢量.在numpy中创建一个热矢量,使用占位符传递它,然后通过矩阵乘法进行切片.这有效,但速度很慢.
任何想法如何基于变量有效切片/索引?