我有一个张量定义如下:
temp_var = tf.Variable(initial_value=np.asarray([[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12]]))
Run Code Online (Sandbox Code Playgroud)
我还有一个从张量中获取的行索引数组:
idx = tf.constant([0, 2])
Run Code Online (Sandbox Code Playgroud)
现在我想temp_var在这些索引中采用一个子集,即idx
我知道要采用单个索引或切片,我们可以做类似的事情
temp_var[single_row_index, :]
Run Code Online (Sandbox Code Playgroud)
要么
temp_var[start:end, :]
Run Code Online (Sandbox Code Playgroud)
但是如何获取idx数组指示的行?有点像temp_var[idx, :]?
mrr*_*rry 10
的tf.gather()运算不正是你需要:它从一个矩阵(或一般(N-1)从一个N维张量维切片)选择行.以下是它在您的情况下的工作方式:
temp_var = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]))
idx = tf.constant([0, 2])
rows = tf.gather(temp_var, idx)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
print(sess.run(rows)) # ==> [[1, 2, 3], [7, 8, 9]]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
4331 次 |
| 最近记录: |