你如何改变Pytorch数据集的大小?

mik*_*305 6 python machine-learning dataset torch pytorch

假设我从torchvision.datasets.MNIST加载MNIST,但我只想加载10000个图像,我如何切片数据以将其限制为只有一些数据点?我知道DataLoader是一个生成器,产生的数据大小与指定的批量大小相同,但是如何对数据集进行切片?

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)
train_loader = DataLoader(tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
Run Code Online (Sandbox Code Playgroud)

sri*_*gde 11

切片数据集的另一种快速方法是使用torch.utils.data.random_split()(在 PyTorch v0.4.1+ 中支持)。它有助于将数据集随机拆分为给定长度的不重叠的新数据集。

所以我们可以有如下内容:

tr = datasets.MNIST('../data', train=True, download=True, transform=transform)
te = datasets.MNIST('../data', train=False, transform=transform)

part_tr = torch.utils.data.random_split(tr, [tr_split_len, len(tr)-tr_split_len])[0]
part_te = torch.utils.data.random_split(te, [te_split_len, len(te)-te_split_len])[0]

train_loader = DataLoader(part_tr, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
test_loader = DataLoader(part_te, batch_size=args.batch_size, shuffle=True, num_workers=4, **kwargs)
Run Code Online (Sandbox Code Playgroud)

在这里,您可以分别设置tr_split_lente_split_len作为训练和测试数据集所需的分割长度。


ent*_*phy 8

重要的是要注意,当您创建DataLoader对象时,它不会立即加载您的所有数据(对于大型数据集来说,这是不切实际的).它为您提供了一个可用于访问每个样本的迭代器.

不幸的是,并DataLoader没有为您提供任何方法来控制您想要提取的样本数量.您将不得不使用切片迭代器的典型方法.

最简单的事情(没有任何库)将在达到所需数量的样本后停止.

nsamples = 10000
for i, image, label in enumerate(train_loader):
    if i > nsamples:
        break

    # Your training code here.
Run Code Online (Sandbox Code Playgroud)

或者,您可以使用itertools.islice获得前10k样本.像这样.

for image, label in itertools.islice(train_loader, stop=10000):

    # your training code here.
Run Code Online (Sandbox Code Playgroud)

  • 此方法的警告:如果您在变量“epoch”上循环多次迭代“train_loader”,您可能已经使用了训练的所有样本...因为“shuffle=True”选项`DataLoader`` 将为每个时期的样本进行洗牌。 (3认同)

uke*_*emi 8

torch.utils.data.Subset()例如,您可以使用前 10,000 个元素:

import torch.utils.data as data_utils

indices = torch.arange(10000)
tr_10k = data_utils.Subset(tr, indices)
Run Code Online (Sandbox Code Playgroud)

  • 这修改了Dataset而不是DataLoader并且非常清楚。 (3认同)