在Tensorflow中,如何解析由tf.nn.max_pool_with_argmax获得的扁平指数?

kar*_*TUM 6 python deep-learning tensorflow

我遇到一个问题:在我使用tf.nn.max_pool_with_argmax之后,我获得了索引,即 argmax: A Tensor of type Targmax. 4-D. The flattened indices of the max values chosen for each output.

如何将展平的索引解开回Tensorflow中的坐标列表?

非常感谢你.

Fab*_*ian 3

我今天遇到了同样的问题,最终得到了这个解决方案:

def unravel_argmax(argmax, shape):
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.pack(output_list)
Run Code Online (Sandbox Code Playgroud)

这是ipython 笔记本中的使用示例(我用它来将池 argmax 位置转发到我的解池方法)