仅更新Tensorflow中单词嵌入矩阵的一部分

use*_*064 10 tensorflow word-embedding

假设我想在训练期间更新预训练的字嵌入矩阵,有没有办法只更新字嵌入矩阵的子集?

我查看了Tensorflow API页面,发现了这个:

# Create an optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)

# grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1])) for gv in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)
Run Code Online (Sandbox Code Playgroud)

但是,我如何将其应用于字嵌入矩阵.假设我这样做:

word_emb = tf.Variable(0.2 * tf.random_uniform([syn0.shape[0],s['es']], minval=-1.0, maxval=1.0, dtype=tf.float32),name='word_emb',trainable=False)

gather_emb = tf.gather(word_emb,indices) #assuming that I pass some indices as placeholder through feed_dict

opt = tf.train.AdamOptimizer(1e-4)
grad = opt.compute_gradients(loss,gather_emb)
Run Code Online (Sandbox Code Playgroud)

然后我如何使用opt.apply_gradientstf.scatter_update更新原始的embeddign矩阵?(另外,如果第二个参数compute_gradient不是a ,则tensorflow会抛出错误tf.Variable)

mrr*_*rry 19

TL; DR:opt.minimize(loss) TensorFlow 的默认实现将生成稀疏更新,word_emb该更新仅修改word_emb参与正向传递的行.

tf.gather(word_emb, indices)op相对于的梯度word_emb是一个tf.IndexedSlices对象(有关更多详细信息,请参阅实现).此对象表示稀疏张量,除了选择的行外,其他位置均为零indices.对调用的opt.minimize(loss)调用AdamOptimizer._apply_sparse(word_emb_grad, word_emb),调用tf.scatter_sub(word_emb, ...)*仅更新word_emb由其选择的行indices.

如果在另一方面,你要修改的tf.IndexedSlices是由返回opt.compute_gradients(loss, word_emb),你可以在它执行任意TensorFlow操作indicesvalues性能,并创建一个新的tf.IndexedSlices可传递到opt.apply_gradients([(word_emb, ...)]).例如,您可以MyCapper()使用以下调用使用(如示例中)调整渐变:

grad, = opt.compute_gradients(loss, word_emb)
train_op = opt.apply_gradients(
    [tf.IndexedSlices(MyCapper(grad.values), grad.indices)])
Run Code Online (Sandbox Code Playgroud)

同样,您可以通过创建tf.IndexedSlices具有不同索引的new来更改将要修改的索引集.


*在一般情况下,如果要更新TensorFlow变量的只是一部分,你可以使用tf.scatter_update(),tf.scatter_add()tf.scatter_sub()运营商,它们分别设置,增加(+=)或(减-=)值预先存储在变量.


Ore*_*ren 6

由于您只想选择要更新的元素(而不是更改渐变),您可以执行以下操作.

我们indices_to_update是一个布尔张量表示要更新索引,并entry_stop_gradients在链接,则定义为:

gather_emb = entry_stop_gradients(gather_emb, indices_to_update)
Run Code Online (Sandbox Code Playgroud)

(来源)