循环张量并将函数应用于每个元素

xuh*_*hai 5 python tensorflow tensor

我想遍历一个包含 的张量Int,并将一个函数应用于每个元素。在函数中,每个元素都将从 python 的字典中获取值。我已经尝试了简单的方法 with tf.map_fn,它可以处理add函数,例如以下代码:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]
Run Code Online (Sandbox Code Playgroud)

但以下代码抛出KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32异常:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
Run Code Online (Sandbox Code Playgroud)

我的 tensorflow 版本是1.13.1. 先谢谢了。

jde*_*esa 0

您不能使用这样的函数,因为参数x是 TensorFlow 张量,而不是 Python 值。因此,为了使其发挥作用,您还必须将字典转换为张量,但这并不是那么简单,因为字典中的键可能不是连续的。

您可以在不进行映射的情况下解决此问题,而是执行类似于此处为 NumPy 建议的操作。在 TensorFlow 中,您可以这样实现:

import tensorflow as tf

def replace_by_dict(x, d):
    # Get keys and values from dictionary
    keys, values = zip(*d.items())
    keys = tf.constant(keys, x.dtype)
    values = tf.constant(values, x.dtype)
    # Make a sequence for the range of values in the input
    v_min = tf.reduce_min(x)
    v_max = tf.reduce_max(x)
    r = tf.range(v_min, v_max + 1)
    r_shape = tf.shape(r)
    # Mask replacements that are out of the input range
    mask = (keys >= v_min) & (keys <= v_max)
    keys = tf.boolean_mask(keys, mask)
    values = tf.boolean_mask(values, mask)
    # Replace values in the sequence with the corresponding replacements
    scatter_idx = tf.expand_dims(keys, 1) - v_min
    replace_mask = tf.scatter_nd(
        scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
    replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
    replacer = tf.where(replace_mask, replace_values, r)
    # Gather the replacement value or the same value if it was not modified
    return tf.gather(replacer, x - v_min)

# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.constant([1, 2, 3])
    print(sess.run(replace_by_dict(a, kv_dict)))
    # [11, 12, 13]
Run Code Online (Sandbox Code Playgroud)

这将允许您在输入张量中拥有无需替换的值(保持原样),并且也不需要在张量中拥有所有替换值。除非输入中的最小值和最大值相差很远,否则它应该是有效的。