如何使用 Pytorch 和/或 Numpy 高效查找多维矩阵数组中最大值的索引

Ser*_*ing 5 python numpy max numba pytorch

背景

处理高维数据在机器学习中很常见。例如,在卷积神经网络 (CNN) 中,每个输入图像的尺寸可以是 256x256,并且每个图像可以具有 3 个颜色通道(红色、绿色和蓝色)。如果我们假设模型一次接收一批 16 张图像,则进入 CNN 的输入的维度为[16,3,256,256]。每个单独的卷积层都期望 形式的数据[batch_size, in_channels, in_y, in_x],并且所有这些数量通常会逐层变化(batch_size 除外)。我们用于表示由值组成的矩阵的术语[in_y, in_x]“特征映射”,这个问题涉及在给定层的每个特征映射中查找最大值及其索引。

我为什么要这样做?我想对每个特征图应用一个掩码,并且我想应用以每个特征图中的最大值为中心的掩码,为此,我需要知道每个最大值所在的位置。这种掩模应用是在模型的训练和测试期间完成的,因此效率对于减少计算时间至关重要。有许多 Pytorch 和 Numpy 解决方案可用于查找单例最大值和索引,以及查找沿单个维度的最大值或索引,但没有(我能找到)专用且高效的内置函数来查找最大值的索引一次沿着 2 个或更多维度。是的,我们可以嵌套在单个维度上运行的函数,但这些是一些效率最低的方法。

我尝试过的

  • 我看过这个 Stackoverflow 问题,但作者正在处理一个特殊情况的 4D 数组,它被简单地压缩为 3D 数组。接受的答案是专门针对这种情况的,而指向 TopK 的答案是误导性的,因为它不仅在单个维度上运行,而且k=1根据所提出的问题需要这样做,从而发展为常规torch.max调用。
  • 我看过这个 Stackoverflow 问题,但是这个问题及其答案,重点关注单一维度。
  • 我已经看过这个 Stackoverflow 问题,但我已经知道答案的方法,因为我在自己的答案中独立地表述了它我修改了该方法非常低效)。
  • 我看过这个Stackoverflow问题,但它不满足这个问题的关键部分,即与效率有关。
  • 我阅读了许多其他 Stackoverflow 问题和答案,以及 Numpy 文档、Pytorch 文档和 Pytorch 论坛上的帖子。
  • 我已经尝试实施很多不同的方法来解决这个问题,足以让我创建这个问题,以便我可以回答它并回馈社区以及将来寻找此问题解决方案的任何人。

绩效标准

如果我问有关效率的问题,我需要清楚地详细说明期望。我正在尝试为上述问题找到一种省时的解决方案(空间是次要的),而无需编写 C 代码/扩展,并且该解决方案相当灵活(超级专业的方法不是我所追求的)。该方法必须接受[a,b,c,d]数据类型 float32 或 float64 的 Torch 张量作为输入,并输出数据类型 int32 或 int64 形式的数组或张量[a,b,2](因为我们使用输出作为索引)。解决方案应根据以下典型解决方案进行基准测试:

max_indices = torch.stack([torch.stack([(x[k][j]==torch.max(x[k][j])).nonzero()[0] for j in range(x.size()[1])]) for k in range(x.size()[0])])
Run Code Online (Sandbox Code Playgroud)

Ser*_*ing 4

该方法

我们将利用 Numpy 社区和库,以及 Pytorch 张量和 Numpy 数组可以相互转换,而无需复制或移动内存中的底层数组(因此转换成本较低)。来自Pytorch 文档

将 torch 张量转换为 Numpy 数组,反之亦然,非常简单。torch Tensor 和 Numpy 数组将共享它们的底层内存位置,改变一个就会改变另一个。

解决方案一

我们首先将使用Numba 库编写一个函数,该函数将在首次使用时进行即时 (JIT) 编译,这意味着我们无需自己编写 C 代码即可获得 C 速度。当然,对于什么可以进行 JIT 编辑,有一些注意事项,其中之一是我们使用 Numpy 函数。但这还不算太糟糕,因为请记住,从我们的 torch 张量转换为 Numpy 的成本很低。我们创建的函数是:

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
Run Code Online (Sandbox Code Playgroud)

这个函数来自位于此处的另一个 Stackoverflow 答案(这是向我介绍 Numba 的答案)。该函数采用 N 维 Numpy 数组并查找给定 的第一次出现item。如果成功匹配,它会立即返回找到的项目的索引。装饰@njit器是 的缩写@jit(nopython=True),告诉编译器我们希望它使用Python 对象来编译函数,如果不能这样做则抛出错误(当不使用 Python 对象时,Numba 是最快的,并且速度更快)就是我们所追求的)。

有了这个快速函数的支持,我们可以得到张量中最大值的索引,如下所示:

import numpy as np

x =  x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = np.zeros((n,p,2),dtype=np.int64)
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index], maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)
Run Code Online (Sandbox Code Playgroud)

我们使用它是np.amax因为它可以接受一个元组作为其axis参数,从而允许它返回 4D 输入中每个 2D 特征图的最大值。我们提前初始化,max_indices因为np.zeros追加到 numpy 数组的成本很高,所以我们提前分配所需的空间。这种方法比问题中的典型解决方案快得多(快一个数量级),但它还使用forJIT 函数外部的循环,因此我们可以改进...

解决方案二

我们将使用以下解决方案:

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))
Run Code Online (Sandbox Code Playgroud)

for我们可以使用 Numbaprange函数(其行为完全相同,range但告诉编译器我们希望循环并行化)和parallel=True装饰器参数来利用并行化,而不是使用循环一次一个地迭代我们的特征映射。Numba 还并行化该np.zeros功能。因为我们的函数是即时编译的并且不使用 Python 对象,所以 Numba 可以利用我们系统中所有可用的线程!值得注意的是,现在有raise RuntimeError一个indexFunc. 我们需要包含它,否则 Numba 编译器将尝试推断函数的返回类型并推断它将是数组或 None。这与我们在 中的用法不一致indexFunc2,因此编译器会抛出错误。当然,从我们的设置中我们知道indexFunc总是会返回一个数组,因此我们可以简单地在另一个逻辑分支中引发错误。

此方法在功能上与解决方案一相同,但将迭代使用更改nd.index为使用的两个for循环prange。这种方法比解决方案一快大约 4 倍。

解决方案三

解决方案二速度很快,但它仍然使用常规 Python 查找最大值。我们可以使用更全面的 JIT 函数来加速这一过程吗?

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc3(x):
    maxVals = np.zeros((x.shape[0],x.shape[1]),dtype=np.float32)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxVals[i][j] = np.max(x[i][j])
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            x[i][j] == np.max(x[i][j])
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

max_indices = torch.from_numpy(indexFunc3(x))
Run Code Online (Sandbox Code Playgroud)

看起来这个解决方案中还发生了很多事情,但唯一的变化是,np.amax我们现在已经并行化了操作,而不是使用 计算每个特征图的最大值。这种方法比解决方案二稍微快一些。

解决方案四

这个解决方案是我能想到的最好的解决方案:

@njit(cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices

max_indices = torch.from_numpy(indexFunc4(x))
Run Code Online (Sandbox Code Playgroud)

这种方法更加简洁,也是最快的,比解决方案三快 33%,比典型解决方案快 50 倍。我们用来np.argmax获取每个特征图最大值的索引, np.argmax只返回索引,就好像每个特征图被展平一样。也就是说,我们得到一个整数,告诉我们该元素在特征图中的编号,而不是我们需要能够访问该元素的索引。数学[maxTemp // x.shape[2], maxTemp % x.shape[2]]是将单数整数转换为[row,column]我们需要的。

标杆管理

所有方法都针对形状 的随机输入进行基准测试[32,d,64,64],其中 d 从 5 增加到 245。对于每个 d,收集 15 个样本并对时间进行平均。平等测试确保所有解决方案提供相同的值。基准输出的一个示例是:

解决方案基准

随着 d 增加,基准测试时间的图为(省略典型解决方案,因此图表不会被压扁):

基准图

哇!这些尖峰一开始发生了什么?

解决方案五

Numba 允许我们生成即时编译的函数,但直到我们第一次使用它们时它才会编译它们;然后它会缓存结果,以便我们再次调用该函数时使用。这意味着我们第一次调用 JIT 函数时,在编译函数时,计算时间会出现峰值。幸运的是,有一种方法可以解决这个问题——如果我们提前指定函数的返回类型和参数类型,函数将被急切地编译,而不是即时编译。将这些知识应用到解决方案四中,我们得到:

@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices    

max_indices6 = torch.from_numpy(indexFunc4(x))
Run Code Online (Sandbox Code Playgroud)

如果我们重新启动内核并重新运行基准测试,我们可以查看第一个结果 whered==5和第二个结果 where d==10,并注意到所有 JIT 解决方案都较慢,d==5因为它们必须进行编译,解决方案四除外,因为我们提前明确提供了函数签名:

在此输入图像描述

我们开始吧!这是迄今为止我对这个问题的最佳解决方案。


编辑#1

解决方案六

我们开发了一种改进的解决方案,比之前发布的最佳解决方案快 33%。此解决方案仅在输入数组是 C 连续的情况下才有效,但这并不是一个很大的限制,因为 numpy 数组或 torch 张量将是连续的,除非它们被重新整形,并且两者都具有在需要时使数组/张量连续的函数。

该解决方案与之前的最佳解决方案相同,但指定输入和返回类型的函数装饰器从

@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)

唯一的区别是:每个数组中的最后一个类型变为::1,这向 numba njit 编译器发出信号,表明输入数组是 C 连续的,从而使其能够更好地优化。

完整的解决方案六是:

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i][j])
            max_indices[i][j] = [maxTemp // x.shape[2], maxTemp % x.shape[2]] 
    return max_indices 

max_indices7 = torch.from_numpy(indexFunc5(x))
Run Code Online (Sandbox Code Playgroud)

包括这个新解决方案的基准测试证实了加速:

基准测试包括解决方案 6