最近的一篇论文(这里)介绍了他们称为中心损失的二次损失函数.它基于批处理中的嵌入与每个相应类的运行平均嵌入之间的距离.TF Google小组(此处)已就如何计算和更新此类嵌入中心进行了一些讨论.我在下面的答案中汇总了一些代码来生成类平均嵌入.
这是最好的方法吗?
先前发布的方法太简单,无法用于中心丢失等情况,其中嵌入的预期值随着模型变得更精细而随时间变化.这是因为先前的中心查找程序对自启动以来的所有实例进行平均,因此非常缓慢地跟踪预期值的变化.相反,移动窗口平均值是优选的.指数移动窗口变体如下:
def get_embed_centers(embed_batch, label_batch):
''' Exponential moving window average. Increase decay for longer windows [0.0 1.0]
'''
decay = 0.95
with tf.variable_scope('embed', reuse=True):
embed_ctrs = tf.get_variable("ctrs")
label_batch = tf.reshape(label_batch, [-1])
old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
return embed_ctrs_batch
with tf.Session() as sess:
with tf.variable_scope('embed'):
embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
initializer=tf.constant_initializer(0), trainable=False)
label_batch_ph = tf.placeholder(tf.int32)
embed_batch_ph = tf.placeholder(tf.float32)
embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
sess.run(tf.initialize_all_variables())
tf.get_default_graph().finalize()
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1385 次 |
| 最近记录: |