use*_*064 10 tensorflow word-embedding
假设我想在训练期间更新预训练的字嵌入矩阵,有没有办法只更新字嵌入矩阵的子集?
我查看了Tensorflow API页面,发现了这个:
# Create an optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1)
# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)
# grads_and_vars is a list of tuples (gradient, variable). Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1])) for gv in grads_and_vars]
# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)
Run Code Online (Sandbox Code Playgroud)
但是,我如何将其应用于字嵌入矩阵.假设我这样做:
word_emb = tf.Variable(0.2 * tf.random_uniform([syn0.shape[0],s['es']], minval=-1.0, maxval=1.0, dtype=tf.float32),name='word_emb',trainable=False)
gather_emb = tf.gather(word_emb,indices) #assuming that I pass some indices as placeholder through feed_dict
opt = tf.train.AdamOptimizer(1e-4)
grad = opt.compute_gradients(loss,gather_emb)
Run Code Online (Sandbox Code Playgroud)
然后我如何使用opt.apply_gradients和tf.scatter_update更新原始的embeddign矩阵?(另外,如果第二个参数compute_gradient不是a ,则tensorflow会抛出错误tf.Variable)
mrr*_*rry 19
TL; DR:opt.minimize(loss) TensorFlow 的默认实现将生成稀疏更新,word_emb该更新仅修改word_emb参与正向传递的行.
tf.gather(word_emb, indices)op相对于的梯度word_emb是一个tf.IndexedSlices对象(有关更多详细信息,请参阅实现).此对象表示稀疏张量,除了选择的行外,其他位置均为零indices.对调用的opt.minimize(loss)调用AdamOptimizer._apply_sparse(word_emb_grad, word_emb),调用tf.scatter_sub(word_emb, ...)*仅更新word_emb由其选择的行indices.
如果在另一方面,你要修改的tf.IndexedSlices是由返回opt.compute_gradients(loss, word_emb),你可以在它执行任意TensorFlow操作indices和values性能,并创建一个新的tf.IndexedSlices可传递到opt.apply_gradients([(word_emb, ...)]).例如,您可以MyCapper()使用以下调用使用(如示例中)调整渐变:
grad, = opt.compute_gradients(loss, word_emb)
train_op = opt.apply_gradients(
[tf.IndexedSlices(MyCapper(grad.values), grad.indices)])
Run Code Online (Sandbox Code Playgroud)
同样,您可以通过创建tf.IndexedSlices具有不同索引的new来更改将要修改的索引集.
*在一般情况下,如果要更新TensorFlow变量的只是一部分,你可以使用tf.scatter_update(),tf.scatter_add()或tf.scatter_sub()运营商,它们分别设置,增加(+=)或(减-=)值预先存储在变量.
由于您只想选择要更新的元素(而不是更改渐变),您可以执行以下操作.
我们indices_to_update是一个布尔张量表示要更新索引,并entry_stop_gradients在链接,则定义为:
gather_emb = entry_stop_gradients(gather_emb, indices_to_update)
Run Code Online (Sandbox Code Playgroud)
(来源)
| 归档时间: |
|
| 查看次数: |
3646 次 |
| 最近记录: |