如何为张量流中的张量的某些条目停止梯度

Jer*_*Eph 1 deep-learning tensorflow

我正在尝试实现嵌入层.嵌入将使用预先训练的手套嵌入进行初始化.对于可以在手套中找到的单词,它将被修复.对于那些没有出现在手套中的单词,它将随机初始化,并且可以训练.我如何在tensorflow中做到这一点?我知道整个张量都有一个tf.stop_gradient,对于这种场景有什么样的stop_gradient api吗?或者,有什么解决方法吗?任何建议表示赞赏

Jer*_*Eph 12

所以我的想法是使用masktf.stop_gradient解决这个问题:

res_matrix = tf.stop_gradient(mask_h*E) + mask*E,

在矩阵中mask,1表示我想要应用渐变的条目,0表示我不想应用渐变的条目(将渐变设置为0),mask_hmask(1翻转为0,0翻转为1 )的投影然后,我们可以从中获取res_matrix.这是测试代码:

import tensorflow as tf
import numpy as np

def entry_stop_gradients(target, mask):
    mask_h = tf.abs(mask-1)
    return tf.stop_gradient(mask_h * target) + mask * target

mask = np.array([1., 0, 1, 1, 0, 0, 1, 1, 0, 1])
mask_h = np.abs(mask-1)

emb = tf.constant(np.ones([10, 5]))

matrix = entry_stop_gradients(emb, tf.expand_dims(mask,1))

parm = np.random.randn(5, 1)
t_parm = tf.constant(parm)

loss = tf.reduce_sum(tf.matmul(matrix, t_parm))
grad1 = tf.gradients(loss, emb)
grad2 = tf.gradients(loss, matrix)
print matrix
with tf.Session() as sess:
    print sess.run(loss)
    print sess.run([grad1, grad2])
Run Code Online (Sandbox Code Playgroud)