Tensorflow多维argmax

Ras*_*mus 7 tensorflow

假设我有一个大小为BxWxHxD的张量.我要处理的张量,使得我有一个新的BxWxHxD张量仅在每个宽x高片的最大元素保持,而所有其他值都为零.

换句话说,我认为实现这一目标的最佳方法是以某种方式在WxH切片上采用2D argmax,从而产生行和列的BxD索引张量,然后可以将其转换为单热BxWxHxD张量以用作一张面具.我该如何工作?

ond*_*jba 8

您可以使用以下函数作为起点。它计算每个批次和每个通道的最大元素的索引。结果数组的格式为 (batch size, 2, number of channels)。

def argmax_2d(tensor):

  # input format: BxHxWxD
  assert rank(tensor) == 4

  # flatten the Tensor along the height and width axes
  flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3]))

  # argmax of the flat tensor
  argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32)

  # convert indexes into 2D coordinates
  argmax_x = argmax // tf.shape(tensor)[2]
  argmax_y = argmax % tf.shape(tensor)[2]

  # stack and return 2D coordinates
  return tf.stack((argmax_x, argmax_y), axis=1)

def rank(tensor):

  # return the rank of a Tensor
  return len(tensor.get_shape())
Run Code Online (Sandbox Code Playgroud)