PyTorch collate_fn: reject a sample and load another in its place

Bri*_*nto 7 python-3.x pytorch dataloader

I have built a Dataset in which I run various checks on the images I am loading. I then pass this Dataset to a DataLoader.

In my Dataset class, I return a sample as None if an image fails my checks, and I have a custom collate_fn that removes all the Nones from the retrieved batch and returns the remaining valid samples.

However, at this point the returned batch can have varying sizes. Is there a way to tell collate_fn to keep sourcing data until the batch reaches a certain length?

class DataSet():
    def __init__(self, csv_file, image_dir):
        # initialise dataset:
        # load csv file and image directory
        self.csv_file = csv_file
        self.image_dir = image_dir

    def __getitem__(self, idx):
        # load one sample
        # if the image is too dark, return None
        # else return one image and its equivalent label

def my_collate(batch):
    # batch size 4: [{tensor image, tensor label}, {}, {}, {}]
    # could come back as e.g. G = [None, {}, {}, {}]
    batch = list(filter(lambda x: x is not None, batch))
    # this gets rid of the Nones; for the example above, G = [{}, {}, {}]
    # but I want len(G) = 4, so how do I sample another dataset entry?
    return torch.utils.data.dataloader.default_collate(batch)

dataset = DataSet(csv_file='../', image_dir='../../')

dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                        num_workers=1, collate_fn=my_collate)

小智 9

This worked for me, because sometimes even those randomly drawn replacements turn out to be None as well.

import numpy as np

def my_collate(batch):
    len_batch = len(batch)
    batch = list(filter(lambda x: x is not None, batch))

    if len_batch > len(batch):
        db_len = len(dataset)  # `dataset` must be in scope here
        diff = len_batch - len(batch)
        while diff != 0:
            a = dataset[np.random.randint(0, db_len)]
            if a is None:  # the replacement can be None too; draw again
                continue
            batch.append(a)
            diff -= 1

    return torch.utils.data.dataloader.default_collate(batch)


Bri*_*nto 8

There are 2 tricks that can be used to solve the problem; choose one of the following approaches:

Fast option: reuse samples from the original batch

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):
        # pad the batch by reusing existing members; this doesn't work
        # if you reject every sample in a batch (batch is then empty)
        diff = len_batch - len(batch)
        batch = batch + batch[:diff]
    return torch.utils.data.dataloader.default_collate(batch)

Better option: otherwise, just load another sample from the dataset at random

def my_collate(batch):
    len_batch = len(batch)  # original batch length
    batch = list(filter(lambda x: x is not None, batch))  # filter out all the Nones
    if len_batch > len(batch):
        # source all the required samples from the original dataset at random
        diff = len_batch - len(batch)
        for i in range(diff):
            batch.append(dataset[np.random.randint(0, len(dataset))])

    return torch.utils.data.dataloader.default_collate(batch)

  • How would you structure the DataLoader's collate_fn argument so that the dataset is in scope? (3 upvotes)
  • Thanks for the code! I think the "better option" should also handle the case where the new sample is None as well, so I guess there should be a while loop or something similar. (3 upvotes)
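Addressing both comments: below is a minimal sketch (the closure and the `make_collate` name are my own additions, not from the answer above) that keeps the dataset in scope by closing over it, and retries in a while loop whenever a replacement sample is itself None:

```python
import random
import torch


def make_collate(dataset):
    """Build a collate_fn that closes over `dataset`, keeping it in scope."""
    def my_collate(batch):
        len_batch = len(batch)  # original batch length
        batch = [x for x in batch if x is not None]  # drop the Nones
        while len(batch) < len_batch:
            # keep drawing until the replacement itself is not None
            candidate = dataset[random.randint(0, len(dataset) - 1)]
            if candidate is not None:
                batch.append(candidate)
        return torch.utils.data.dataloader.default_collate(batch)
    return my_collate
```

You would then pass `make_collate(dataset)` as the `collate_fn` argument of the DataLoader.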

Ibr*_*zic 7

[EDIT] An updated version of the code snippet below can be found here: https://github.com/project-lighter/lighter/blob/main/lighter/utils/collate.py

Thanks to Brian Formento both for asking the question and for suggesting the idea to solve it. As mentioned, the best option of replacing bad examples with new ones has two issues:

  1. The newly sampled examples could also be corrupted;
  2. The dataset is not in scope.

Here is a solution to both of them: issue 1 is solved with a recursive call, and issue 2 by creating a partial function of the collate function with the dataset fixed in place.

import random
import torch


def collate_fn_replace_corrupted(batch, dataset):
    """Collate function that allows replacing corrupted examples in the
    dataloader. It expects the dataset to return `None` when an example
    is corrupted. The `None`s in the batch are replaced with other
    examples sampled at random.

    Args:
        batch (torch.Tensor): batch from the DataLoader.
        dataset (torch.utils.data.Dataset): dataset that the DataLoader is loading.
            Specify it with functools.partial and pass the resulting partial function
            that only requires the `batch` argument to DataLoader's `collate_fn` option.

    Returns:
        torch.Tensor: batch with new examples instead of corrupted ones.
    """
    # Idea from https://stackoverflow.com/a/57882783

    original_batch_len = len(batch)
    # Filter out all the Nones (corrupted examples)
    batch = list(filter(lambda x: x is not None, batch))
    filtered_batch_len = len(batch)
    # Num of corrupted examples
    diff = original_batch_len - filtered_batch_len
    if diff > 0:
        # Replace corrupted examples with other, randomly drawn examples
        batch.extend([dataset[random.randint(0, len(dataset) - 1)] for _ in range(diff)])
        # Recursive call to replace the replacements if they are corrupted
        return collate_fn_replace_corrupted(batch, dataset)
    # Finally, when the whole batch is fine, return it
    return torch.utils.data.dataloader.default_collate(batch)


However, you cannot pass this function directly to the DataLoader, because a collate function must take only a single argument: batch. To achieve that, we create a partial function with the dataset specified, and pass the partial function to the DataLoader.

import functools
from torch.utils.data import DataLoader


collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        pin_memory=pin_memory,
                        collate_fn=collate_fn)
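For completeness, here is a self-contained sketch of the whole approach end to end. `ToyDataset` is an illustrative stand-in of my own (every third index simulates a "too dark" image by returning None), not part of the answer above; the collate function is the same recursive idea in compact form:

```python
import functools
import random
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Toy dataset: every third sample is 'corrupted' and returned as None."""
    def __init__(self, n=32):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if idx % 3 == 0:
            return None  # simulate a rejected (too dark) image
        return torch.full((3,), float(idx)), idx % 2  # (image, label)


def collate_fn_replace_corrupted(batch, dataset):
    # Filter out the Nones, re-sample replacements, recurse on the result
    # in case a replacement is itself corrupted.
    original_batch_len = len(batch)
    batch = [x for x in batch if x is not None]
    diff = original_batch_len - len(batch)
    if diff > 0:
        batch.extend(dataset[random.randint(0, len(dataset) - 1)]
                     for _ in range(diff))
        return collate_fn_replace_corrupted(batch, dataset)
    return torch.utils.data.dataloader.default_collate(batch)


dataset = ToyDataset()
collate_fn = functools.partial(collate_fn_replace_corrupted, dataset=dataset)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

for images, labels in loader:
    assert images.shape == (4, 3)  # every batch is full despite the Nones
```

Every batch that comes out of the loader has the full batch size, even though a third of the raw samples were rejected.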