我想知道下面的代码是否有更有效的替代方案,而不在第四行中使用“for”循环?
import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)
Run Code Online (Sandbox Code Playgroud)
基本上,我想做的是创建一个n
掩码d
张量,使得每一行中的k
随机元素都是 True。
这是一种无需循环即可完成此操作的方法。让我们从一个随机矩阵开始,其中所有元素都是独立同分布的,在本例中均匀地在 [0,1] 上。然后我们取每行的第 k 个分位数,并将每行上所有较小或相等的元素设置为 True,其余元素设置为 False:
rand_mat = torch.rand(n, d)
k_th_quant = torch.topk(rand_mat, k, largest = False)[0][:,-1:]
mask = rand_mat <= k_th_quant
Run Code Online (Sandbox Code Playgroud)
无需循环:) 比您在我的 CPU 上附加的代码快 x2.1598。
归档时间: |
|
查看次数: |
1874 次 |
最近记录: |