用 Pytorch 随机选择?

Nic*_*ais 15 python numpy machine-learning python-3.x pytorch

我有一个张量的图片,想从中随机选择。我正在寻找相当于np.random.choice().

import torch

pictures = torch.randint(0, 256, (1000, 28, 28, 3))
Run Code Online (Sandbox Code Playgroud)

假设我想要 10 张这样的照片。

Nic*_*ais 33

torch没有等效的实现np.random.choice(),请参阅此处的讨论。另一种方法是使用混洗索引或随机整数进行索引。

替换来做到这一点:

  1. 生成n随机索引
  2. 用这些索引索引你的原始张量
pictures[torch.randint(len(pictures), (10,))]  
Run Code Online (Sandbox Code Playgroud)

要做到这一点而无需更换:

  1. 洗牌索引
  2. 取第n个元素
indices = torch.randperm(len(pictures))[:10]

pictures[indices]
Run Code Online (Sandbox Code Playgroud)

阅读更多关于torch.randinttorch.randperm。第二个代码片段的灵感来自PyTorch 论坛中的这篇文章


uke*_*emi 12

torch.multinomial提供与 numpy 等效的行为random.choice(包括带/不带替换的采样):

# Uniform weights for random draw
unif = torch.ones(pictures.shape[0])

idx = unif.multinomial(10, replacement=True)
samples = pictures[idx]
Run Code Online (Sandbox Code Playgroud)
samples.shape
>>> torch.Size([10, 28, 28, 3])
Run Code Online (Sandbox Code Playgroud)


小智 7

对于这个大小的张量:

N, D = 386363948, 2
k = 190973
values = torch.randn(N, D)
Run Code Online (Sandbox Code Playgroud)

下面的代码运行得相当快。大约需要0.2秒:

indices = torch.tensor(random.sample(range(N), k))
indices = torch.tensor(indices)
sampled_values = values[indices]
Run Code Online (Sandbox Code Playgroud)

torch.randperm然而,使用会花费 20 多秒:

sampled_values = values[torch.randperm(N)[:k]]
Run Code Online (Sandbox Code Playgroud)