我有一个张量,我使用tf.map_fn. 现在我想将索引作为参数包含在我传递给的函数中tf.map_fn。在 numpy 中,我可以enumerate用来获取该信息并将其传递到我的 lambda 函数中。这是 numpy 中的一个示例,我将 0 添加到第一行,将 1 添加到第二行,依此类推:
a = np.array([[2, 1], [4, 2], [-1, 2]])
def some_function(x, i):
return x + i
res = map(lambda (i, row): some_function(row, i), enumerate(a))
print(res)
> [array([2, 1]), array([5, 3]), array([1, 4])]
Run Code Online (Sandbox Code Playgroud)
我一直无法enumerate在 tensorflow 中找到等效项,我不知道如何在 tensorflow 中获得相同的结果。有人知道用什么让它在 tensorflow 中工作吗?这是一个示例代码,我在其中的每一行中添加 1 a:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
with tf.Session() as sess:
res = tf.map_fn(lambda row: some_function(row, 1), a)
print(res.eval())
> [[3 2]
[5 3]
[0 3]]
Run Code Online (Sandbox Code Playgroud)
感谢任何可以帮助我解决这个问题的人。
tf.map_fn()可以有很多输入/输出。因此,您可以使用tf.range()来构建行索引张量并将其用于:
import tensorflow as tf
def some_function(x, i):
return x + i
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]),
(a, a_rows), dtype=(tf.int32, tf.int32))
with tf.Session() as sess:
print(res.eval())
# [[2 1]
# [5 3]
# [1 4]]
Run Code Online (Sandbox Code Playgroud)
注意:在许多情况下,“逐行处理矩阵”可以立即完成,例如通过广播,而不是使用循环:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows
with tf.Session() as sess:
print(res.eval())
# [[2 1]
# [5 3]
# [1 4]]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
3387 次 |
| 最近记录: |