如何实现标记嵌入的中心损失和其他运行平均值

Rob*_*obR 6 tensorflow

最近的一篇论文(这里)介绍了他们称为中心损失的二次损失函数.它基于批处理中的嵌入与每个相应类的运行平均嵌入之间的距离.TF Google小组(此处)已就如何计算和更新此类嵌入中心进行了一些讨论.我在下面的答案中汇总了一些代码来生成类平均嵌入.

这是最好的方法吗?

Rob*_*obR 5

先前发布的方法太简单,无法用于中心丢失等情况,其中嵌入的预期值随着模型变得更精细而随时间变化.这是因为先前的中心查找程序对自启动以来的所有实例进行平均,因此非常缓慢地跟踪预期值的变化.相反,移动窗口平均值是优选的.指数移动窗口变体如下:

def get_embed_centers(embed_batch, label_batch):
    ''' Exponential moving window average. Increase decay for longer windows [0.0 1.0]
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")

    label_batch = tf.reshape(label_batch, [-1])
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch


with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.initialize_all_variables())
    tf.get_default_graph().finalize()
Run Code Online (Sandbox Code Playgroud)