如何从 pytorch DataLoader 获取特定样本?

MJi*_*ter 6 pytorch

在 Pytorch 中,有没有办法使用类加载特定的torch.utils.data.DataLoader单个样本?我想用它做一些测试。

教程使用

trainloader = torch.utils.data.DataLoader(...)
images, labels = next(iter(trainloader))
Run Code Online (Sandbox Code Playgroud)

获取一批随机样本。有没有办法使用DataLoader来获取特定样本?

干杯

muj*_*iga 8

  • 关闭shuffle输入DataLoader
  • 用于batch_size计算您要查找的所需样品所属的批次
  • 迭代到所需的批次

代码

import torch 
import numpy as np
import itertools

X= np.arange(100)
batch_size = 2

dataloader = torch.utils.data.DataLoader(X, batch_size=batch_size, shuffle=False)
sample_at = 5
k = int(np.floor(sample_at/batch_size))

my_sample = next(itertools.islice(dataloader, k, None))
print (my_sample)
Run Code Online (Sandbox Code Playgroud)

输出:

tensor([4, 5])
Run Code Online (Sandbox Code Playgroud)