如何在 PyTorch 中的张量的每一行中随机设置固定数量的元素

sis*_*man 3 pytorch

我想知道下面的代码是否有更有效的替代方案,而不在第四行中使用“for”循环?

import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)
Run Code Online (Sandbox Code Playgroud)

基本上,我想做的是创建一个n掩码d张量,使得每一行中的k随机元素都是 True。

Gil*_*sky 6

这是一种无需循环即可完成此操作的方法。让我们从一个随机矩阵开始,其中所有元素都是独立同分布的,在本例中均匀地在 [0,1] 上。然后我们取每行的第 k 个分位数,并将每行上所有较小或相等的元素设置为 True,其余元素设置为 False:

rand_mat = torch.rand(n, d)
k_th_quant = torch.topk(rand_mat, k, largest = False)[0][:,-1:]
mask = rand_mat <= k_th_quant
Run Code Online (Sandbox Code Playgroud)

无需循环:) 比您在我的 CPU 上附加的代码快 x2.1598。