张量流中的排列优化

Kh4*_*tiK 5 algorithm permutation tensorflow

问题设置

我有一批4x4矩阵,包含实值条目。

L = tf.placeholder('float32', shape=[None, 4, 4], name='pairwise-loss')
Run Code Online (Sandbox Code Playgroud)

我想找到一批 4-permutations,使得每个 4-permutation 最小化将排列作为单热掩码应用的总和。

def get_best_permutation(L_):
    '''
    Returns a batch of permutations `P` such that `tf.reduce_sum(tf.one_hot(P) * L_, axis=(1,2))`
      is minimized for each batch element.

    Args:
        L_: real valued tensor of shape [None, 4, 4]

    Returns:
        P: tf.int64 tensor of shape [None, 4]

    '''
    raise NotImpelmentedError()
Run Code Online (Sandbox Code Playgroud)

排列大小小且恒定,通常为 3-4。然而,批量大小预计会非常大。理想情况下,一切都应该在图形内完成,并在 GPU 上完成,因此数据传输更少。

编辑假设所有矩阵条目都是正数是安全的。

背景

这是为了实现类似于Permutation Invariant Training 的东西。

快速而肮脏的解决方案

可以预先计算所有可能的排列,因为它们很小,然后并行应用所有排列。最后申请tf.argmin找到最好的。但是我想要一个更有效的解决方案。