相关疑难解决方法(0)

在Tensorflow中计算批量中的成对距离而不复制张量?

我想计算Tensorflow中一批特征的成对平方距离.我通过平铺原始张量使用+和*操作有一个简单的实现:

def pairwise_l2_norm2(x, y, scope=None):
    with tf.op_scope([x, y], scope, 'pairwise_l2_norm2'):
        size_x = tf.shape(x)[0]
        size_y = tf.shape(y)[0]
        xx = tf.expand_dims(x, -1)
        xx = tf.tile(xx, tf.pack([1, 1, size_y]))

        yy = tf.expand_dims(y, -1)
        yy = tf.tile(yy, tf.pack([1, 1, size_x]))
        yy = tf.transpose(yy, perm=[2, 1, 0])

        diff = tf.sub(xx, yy)
        square_diff = tf.square(diff)

        square_dist = tf.reduce_sum(square_diff, 1)

        return square_dist
Run Code Online (Sandbox Code Playgroud)

该函数将两个大小为(m,d)和(n,d)的矩阵作为输入,并计算每个行向量之间的平方距离.输出是大小为(m,n)的矩阵,其元素为'd_ij = dist(x_i,y_j)'.

问题是我有一个大批量和高昏暗的功能'm,n,d'复制张量消耗了大量的内存.我正在寻找另一种方法来实现它,而不增加内存使用量,只是存储最终的距离张量.一种双循环原始张量.

python tensorflow

30
推荐指数
3
解决办法
1万
查看次数

标签 统计

python ×1

tensorflow ×1