在单 GPU 上使用 Pytorch 并行化简单的 for 循环

Pra*_*tik 5 python numpy pytorch

我有一个 for 循环,它对一个大矩阵的独立列进行操作。我使用 Numba 中的 prange 函数在 CPU 上并行化了 for 循环。现在我想在 GPU 上使用 PyTorch 张量执行此操作。我是 PyTorch 新手,不知道该怎么做。

任何帮助将不胜感激。

我的Python代码如下:

def select_next(X, gains, current_values, mask):
    for idx in prange(X.shape[0]):
        if mask[idx] == 1:
            continue

        a = numpy.maximum(X[idx], current_values)
        gains[idx] = (a - current_values).sum()
    return numpy.argmax(gains)
Run Code Online (Sandbox Code Playgroud)

我的PyTorch代码如下:

def select_next(X, gains, current_values, mask):
    for idx in range(X.shape[0]):
        if mask[idx].item() == 1:
            continue

        a = torch.max(X[idx], current_values)
        gains[idx] = torch.sum(torch.sub(a, current_values))
    return torch.argmax(gains)
Run Code Online (Sandbox Code Playgroud)

如何并行化 for 循环?