Tags: python, pytorch, pytorch-dataloader
I want to train a classifier on the ImageNet dataset (1000 classes), and I need each batch to contain 64 images from the same class, with consecutive batches drawn from different classes. So far, based on @shai's suggestion and this post, I have:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
class DS(Dataset):
    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.indices = [[] for _ in range(num_classes)]
        for i, (_, class_label) in enumerate(data):
            # build a list of lists, where every sublist contains the indices of
            # the samples that belong to that class_label
            self.indices[class_label].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]


class BatchSampler:
    def __init__(self, classes, batch_size):
        # classes is a list of lists where each sublist refers to a class and
        # contains the sample ids that belong to this class
        self.classes = classes
        self.n_batches = sum(len(x) for x in classes) // batch_size
        self.min_class_size = min(len(x) for x in classes)
        self.batch_size = batch_size
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)
        assert batch_size <= self.min_class_size, \
            'batch_size must not exceed the smallest class size ({})'.format(self.min_class_size)

    def __iter__(self):
        batches = []
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            # sample without replacement so a batch holds no duplicate images
            batches.append(np.random.choice(self.classes[batch_class], self.batch_size,
                                            replace=False))
        return iter(batches)
def main():
    # Wrap the base dataset, then build the class-grouped batch sampler and loader
    _train_dataset = DS(train_dataset, train_dataset.num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)

    labels = []
    for i, (inputs, _labels) in enumerate(_train_loader):
        labels.append(torch.unique(_labels).item())
        print("Unique labels: {}".format(torch.unique(_labels).item()))
    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None; train_sampler will be replaced
    train_sampler, val_sampler = None, None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()
This prints: Length of traversed unique labels: 100.
However, building self.indices in the for loop takes a lot of time. Is there a more efficient way to construct this sampler?
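The slowness comes from the fact that enumerate(data) calls __getitem__ on every sample, so each image is loaded and transformed just to read its label. For ImageFolder (or any DatasetFolder) the labels are already held in memory in the dataset's targets attribute, so the per-class index lists can be built without touching the images at all. A minimal sketch of such a helper (build_per_class_indices is a hypothetical name; it assumes the dataset exposes targets, which FakeData, for instance, does not):

import numpy as np

def build_per_class_indices(dataset, num_classes):
    # ImageFolder/DatasetFolder keep every label in dataset.targets,
    # so no __getitem__ call (and hence no image decoding) is needed
    targets = np.asarray(dataset.targets)
    return [np.flatnonzero(targets == c).tolist() for c in range(num_classes)]

With this, DS.__init__ could reduce to self.indices = build_per_class_indices(data, num_classes) whenever the wrapped dataset has a targets attribute.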
Edit: A yield-based implementation:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path
class DS(Dataset):
    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)
        indices = [[] for _ in range(num_classes)]
        for i, (_, class_label) in tqdm(enumerate(data), total=len(data), miniters=1,
                                        desc='Building class indices dataset..'):
            indices[class_label].append(i)
        self.indices = indices

    def per_class_sample_indices(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.data_len
class BatchSampler:
    def __init__(self, per_class_sample_indices, batch_size):
        # per_class_sample_indices is a list of lists where each sublist refers to
        # a class and contains the sample ids that belong to this class
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum(len(x) for x in per_class_sample_indices) // batch_size
        self.min_class_size = min(len(x) for x in per_class_sample_indices)
        self.batch_size = batch_size
        self.class_range = list(range(len(self.per_class_sample_indices)))
        random.shuffle(self.class_range)

    def __iter__(self):
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            if self.batch_size <= len(self.per_class_sample_indices[batch_class]):
                # sample without replacement so a batch holds no duplicate images
                batch = np.random.choice(self.per_class_sample_indices[batch_class],
                                         self.batch_size, replace=False)
            else:
                batch = self.per_class_sample_indices[batch_class]
            yield batch

    def __len__(self):
        # renamed from n_batches(); that method name collided with the attribute
        return self.n_batches
def main():
    # Build the per-class index lists once and cache them on disk
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    if not os.path.exists(os.path.join(file_path, file_name)):
        print('File: {} does not exist. Creating it.'.format(file_name))
        per_class_sample_indices = DS(train_dataset, num_classes).per_class_sample_indices()
        torch.save(per_class_sample_indices, os.path.join(file_path, file_name))
    else:
        per_class_sample_indices = torch.load(os.path.join(file_path, file_name))
        print('File: {} exists. Not creating it.'.format(file_name))

    batch_sampler = BatchSampler(per_class_sample_indices,
                                 batch_size=args.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # batch_size, shuffle and sampler are mutually exclusive with batch_sampler
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=batch_sampler
    )
    # We do not use the sampler for validation
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset, batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True, sampler=None)

    labels = []
    for i, (inputs, _labels) in enumerate(train_loader):
        labels.append(torch.unique(_labels).item())
        print("Unique labels: {}".format(torch.unique(_labels).item()))
    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        num_classes = len(train_dataset.classes)

    main()
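One detail worth flagging in the sampler above: class_range is shuffled once in __init__, so every epoch walks the classes in the same order. If a fresh order per epoch is wanted, the shuffle can move into __iter__, which the DataLoader re-invokes at the start of each epoch. A minimal sketch of that variant (ReshufflingBatchSampler is a hypothetical name):

import random
import numpy as np

class ReshufflingBatchSampler:
    def __init__(self, per_class_sample_indices, batch_size):
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum(len(x) for x in per_class_sample_indices) // batch_size
        self.batch_size = batch_size
        self.class_range = list(range(len(per_class_sample_indices)))

    def __iter__(self):
        # reshuffle each time the DataLoader starts a new epoch
        random.shuffle(self.class_range)
        for j in range(self.n_batches):
            batch_class = (self.class_range[j] if j < len(self.class_range)
                           else random.choice(self.class_range))
            candidates = self.per_class_sample_indices[batch_class]
            if self.batch_size <= len(candidates):
                # without replacement: no duplicate samples within a batch
                yield np.random.choice(candidates, self.batch_size, replace=False)
            else:
                yield candidates

    def __len__(self):
        return self.n_batches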
A similar post, but for TensorFlow, can be found here.