使用 PyTorch 在每次迭代中仅从一个类中高效地采样批次

Tho*_*oth 5 python pytorch pytorch-dataloader

我想在 ImageNet 数据集(1000 个类别)上训练分类器,要求每个批次包含来自同一类别的 64 张图像,并且连续的批次来自不同的类别。到目前为止,根据 @Shai 的建议和这篇文章,我得到了如下实现:

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os


class DS(Dataset):
    """Wrap a labelled dataset and index its samples by class.

    Args:
        data: indexable dataset whose items are ``(sample, class_label)``
            pairs.
        num_classes: number of per-class index lists to allocate; every
            label must be a valid index into a list of that length.

    After construction ``self.indices[c]`` holds the positions in ``data``
    of the samples labelled ``c``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data

        # One (initially empty) list of sample positions per class.
        self.indices = [[] for _ in range(num_classes)]
        for i, class_label in enumerate(self._labels(data)):
            self.indices[class_label].append(i)

    @staticmethod
    def _labels(data):
        """Return an iterable over the class labels of ``data`` in index order.

        If ``data`` carries a ``targets`` sequence of matching length (as
        torchvision's folder-based datasets are documented to) and no
        ``target_transform`` is set, the labels are read from it directly.
        This avoids loading and transforming every sample just to look at
        its label. Otherwise fall back to iterating the dataset itself.
        """
        targets = getattr(data, 'targets', None)
        if (targets is not None
                and getattr(data, 'target_transform', None) is None
                and len(targets) == len(data)):
            return (int(label) for label in targets)
        return (class_label for _, class_label in data)

    def classes(self):
        """Return the list of per-class sample-index lists."""
        return self.indices

    def __getitem__(self, index):
        """Return the ``(sample, class_label)`` pair stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)


class BatchSampler:
    """Batch sampler that draws every batch from a single class.

    Args:
        classes: list of lists; ``classes[c]`` contains the sample ids that
            belong to class ``c``.
        batch_size: number of sample ids per batch. Must be strictly smaller
            than the size of the smallest class.

    The class order is shuffled once, at construction time. The first
    ``len(classes)`` batches of an iteration each use a different class in
    that order; any further batches use a randomly chosen class.
    """

    def __init__(self, classes, batch_size):
        self.classes = classes
        self.n_batches = sum(len(x) for x in classes) // batch_size
        self.min_class_size = min(len(x) for x in classes)
        self.batch_size = batch_size
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)

        # The message now states the condition that is actually checked.
        assert batch_size < self.min_class_size, \
            'batch_size should be less than {}'.format(self.min_class_size)

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.n_batches

    def __iter__(self):
        """Return an iterator over arrays of ``batch_size`` sample ids."""
        batches = []
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            # replace=False: a batch must not contain the same sample twice.
            # This is always feasible because batch_size < min_class_size.
            batches.append(np.random.choice(self.classes[batch_class],
                                            self.batch_size, replace=False))
        return iter(batches)


def main():
    """Iterate the class-homogeneous loader once and report the labels seen.

    Reads the module-level ``train_dataset`` and ``args``. Each batch is
    expected to contain exactly one distinct label; ``Tensor.item()`` raises
    if a batch mixes classes.
    """
    # torchvision's FakeData stores ``num_classes``; ImageFolder does not and
    # exposes the class names as ``classes`` instead, so fall back to that.
    num_classes = getattr(train_dataset, 'num_classes', None)
    if num_classes is None:
        num_classes = len(train_dataset.classes)

    _train_dataset = DS(train_dataset, num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)

    labels = set()
    for inputs, _labels in _train_loader:
        # Compute the batch's single label once and reuse it.
        label = torch.unique(_labels).item()
        labels.add(label)
        print("Unique labels: {}".format(label))

    print('Length of traversed unique labels: {}'.format(len(labels)))


if __name__ == "__main__":
    # Script entry point: parse the command line, build the train/validation
    # datasets (real ImageNet folders or synthetic FakeData), then run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # The help text now reports the real default (64).
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training images get random crop/flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation images get a deterministic resize + centre crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None and train_sampler will be replaced.
    # NOTE(review): main() builds its own loader from train_dataset; the two
    # loaders below are not referenced by it in this listing.
    train_sampler, val_sampler = None, None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()
Run Code Online (Sandbox Code Playgroud)

打印:Length of traversed unique labels: 100.

然而,在 for 循环中构建 self.indices 需要花费大量时间。有没有更高效的方法来构建这个采样器?

编辑:使用 yield 的实现

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path


class DS(Dataset):
    """Wrap a labelled dataset and index its samples by class.

    Args:
        data: sized, indexable dataset whose items are
            ``(sample, class_label)`` pairs.
        num_classes: number of per-class index lists to allocate; every
            label must be a valid index into a list of that length.

    After construction ``self.indices[c]`` holds the positions in ``data``
    of the samples labelled ``c``.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)

        indices = [[] for _ in range(num_classes)]

        labels = self._labels(data)
        for i, class_label in tqdm(enumerate(labels), total=len(data), miniters=1,
                                   desc='Building class indices dataset..'):
            indices[class_label].append(i)

        self.indices = indices

    @staticmethod
    def _labels(data):
        """Return an iterable over the class labels of ``data`` in index order.

        If ``data`` carries a ``targets`` sequence of matching length (as
        torchvision's folder-based datasets are documented to) and no
        ``target_transform`` is set, the labels are read from it directly.
        This avoids loading and transforming every sample just to look at
        its label. Otherwise fall back to iterating the dataset itself.
        """
        targets = getattr(data, 'targets', None)
        if (targets is not None
                and getattr(data, 'target_transform', None) is None
                and len(targets) == len(data)):
            return (int(label) for label in targets)
        return (class_label for _, class_label in data)

    def per_class_sample_indices(self):
        """Return the list of per-class sample-index lists."""
        return self.indices

    def __getitem__(self, index):
        """Return the ``(sample, class_label)`` pair stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples, as measured at construction time."""
        return self.data_len


class BatchSampler:
    """Batch sampler that draws every batch from a single class.

    Args:
        per_class_sample_indices: list of lists; entry ``c`` contains the
            sample ids that belong to class ``c``.
        batch_size: number of sample ids per batch.

    The class order is shuffled once, at construction time. The first
    ``len(per_class_sample_indices)`` batches of an iteration each use a
    different class in that order; any further batches use a randomly chosen
    class. A class with fewer than ``batch_size`` samples is yielded whole,
    so such batches are shorter.
    """

    def __init__(self, per_class_sample_indices, batch_size):
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum(len(x) for x in per_class_sample_indices) // batch_size
        self.min_class_size = min(len(x) for x in per_class_sample_indices)
        self.batch_size = batch_size
        self.class_range = list(range(len(self.per_class_sample_indices)))
        random.shuffle(self.class_range)

    def __iter__(self):
        """Yield one collection of sample ids per batch."""
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            class_indices = self.per_class_sample_indices[batch_class]
            if self.batch_size <= len(class_indices):
                # replace=False: a batch must not contain the same sample twice.
                batch = np.random.choice(class_indices, self.batch_size, replace=False)
            else:
                # Not enough samples for a full batch: use the whole class.
                batch = class_indices
            yield batch

    # The former ``n_batches()`` method could never be called: the instance
    # attribute ``self.n_batches`` set in ``__init__`` shadows it. ``__len__``
    # exposes the same number and also makes ``len(DataLoader)`` work.
    def __len__(self):
        """Return the number of batches produced per iteration."""
        return self.n_batches


def main():
    """Load or build the per-class index cache, then verify the batches.

    Reads the module-level ``train_dataset``, ``num_classes`` and ``args``.
    The per-class sample indices are cached on disk with ``torch.save`` and
    reused on later runs. Every batch from the loader is expected to hold a
    single distinct label; ``Tensor.item()`` raises otherwise.
    """
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    cache_file = os.path.join(file_path, file_name)

    if os.path.exists(cache_file):
        # Reuse the index lists written by an earlier run.
        per_class_sample_indices = torch.load(cache_file)
        print('File: {} exists. Do not create it.'.format(file_name))
    else:
        print('File: {} does not exists. Create it.'.format(file_name))
        wrapped = DS(train_dataset, num_classes)
        per_class_sample_indices = wrapped.per_class_sample_indices()
        torch.save(per_class_sample_indices, cache_file)

    batch_sampler = BatchSampler(per_class_sample_indices, batch_size=args.batch_size)

    # batch_sampler takes the place of batch_size / shuffle / sampler.
    # No sampler is used for validation, so no validation loader is built here.
    loader_options = dict(
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

    seen_labels = set()
    for inputs, _labels in train_loader:
        batch_label = torch.unique(_labels).item()
        seen_labels.add(batch_label)
        print("Unique labels: {}".format(batch_label))

    print('Length of traversed unique labels: {}'.format(len(seen_labels)))


if __name__ == "__main__":
    # Script entry point: parse the command line, build the train/validation
    # datasets (real ImageNet folders or synthetic FakeData), set
    # ``num_classes`` accordingly, then run main().

    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    # The help text now reports the real default (64).
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        # Training images get random crop/flip augmentation.
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        # Validation images get a deterministic resize + centre crop.
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        # One class per entry of the dataset's ``classes`` list.
        num_classes = len(train_dataset.classes)

    main()
Run Code Online (Sandbox Code Playgroud)

可以在此处找到类似的帖子,但在 TensorFlow 中

Sha*_*hai 2

您应该为 DataLoader 编写自己的 batch_sampler。