在 Numpy/PyTorch 中快速查找值大于阈值的索引

Gab*_*Chu 4 python numpy pytorch

任务

给定一个 numpy 或 pytorch 矩阵,找到值大于给定阈值的单元格的索引。

我的实现

#abs_cosine is the matrix
#sim_vec is the wanted

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))
Run Code Online (Sandbox Code Playgroud)

顾虑

速度。所有其他计算都建立在 Pytorch 上,使用numpy已经是一种妥协,因为它已经将计算从 GPU 转移到了 CPU。纯 pythonfor循环会使整个过程变得更糟(对于小数据集已经慢了 5 倍)。我想知道我们是否可以在不调用任何for循环的情况下将整个计算移动到 Numpy(或 pytorch)?

我能想到的改进(但卡住了......)

bool_cosine = abs_cosine > 阈值

它返回一个True和的布尔矩阵False。但是我找不到快速检索True单元格索引的方法。

lay*_*yog 5

以下是 PyTorch(完全在 GPU 上)

# abs_cosine should be a Tensor of shape (m, m)
mask = torch.ones(abs_cosine.size()[0])
mask = 1 - mask.diag()
sim_vec = torch.nonzero((abs_cosine >= threshold)*mask)

# sim_vec is a tensor of shape (?, 2) where the first column is the row index and second is the column index
Run Code Online (Sandbox Code Playgroud)

以下在 numpy 中有效

mask = 1 - np.diag(np.ones(abs_cosine.shape[0]))
sim_vec = np.nonzero((abs_cosine >= 0.2)*mask)
# sim_vec is a 2-array tuple where the first array is the row index and the second array is column index
Run Code Online (Sandbox Code Playgroud)