How to use random_split with percentage splits (Sum of input lengths does not equal the length of the input dataset)

Pro*_*o Q 8 dataset pytorch

I tried to use torch.utils.data.random_split as follows:

import torch
from torch.utils.data import DataLoader, random_split

list_dataset = [1,2,3,4,5,6,7,8,9,10]
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)

random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))

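However, when I try this, I get the error raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

I looked at the documentation, and it seems I should be able to pass decimals that sum to 1, but clearly it is not working.

I also googled the error, and the closest thing I found was this issue.

What am I doing wrong?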

Pro*_*o Q 9

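You are probably using an older version of PyTorch, such as PyTorch 1.10, which does not have this feature.

To replicate this functionality in an older version, you can simply copy the source code from the newer version: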

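import math
import warnings  # needed for the zero-length split warning below
from typing import List  # used in the subset_lengths annotation

from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data.dataset import Subset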

def random_split(dataset, lengths,
                 generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
    >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
    ...   ).manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths: List[int] = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):    # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
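To see what the fraction branch of the function above does, here is a minimal pure-Python sketch of just the length computation (floor each share, then hand out the remainder round-robin). The helper name fractions_to_lengths is hypothetical and is not part of PyTorch. It also shows why the older version fails: with [0.8, 0.1, 0.1] the lengths sum to 1, not to the dataset length of 10, so the final check raises.

```python
import math
from typing import List, Sequence


def fractions_to_lengths(n_items: int, fractions: Sequence[float]) -> List[int]:
    """Turn fractions that sum to 1 into integer lengths that sum to n_items.

    Hypothetical helper for illustration; it mirrors the fraction branch
    of the random_split source quoted above.
    """
    if not math.isclose(sum(fractions), 1.0):
        raise ValueError("fractions must sum to 1")
    # Step 1: floor(frac * n_items) for every fraction.
    lengths = [math.floor(n_items * frac) for frac in fractions]
    # Step 2: distribute the leftover items one at a time, round-robin.
    remainder = n_items - sum(lengths)
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(fractions_to_lengths(10, [0.8, 0.1, 0.1]))  # [8, 1, 1]
print(fractions_to_lengths(7, [0.5, 0.5]))        # [4, 3]
```

On an older PyTorch, the resulting integer list could be passed directly as the lengths argument of the built-in random_split instead of copying the whole function. Note that random_split is meant to be applied to the dataset itself (here the plain list), with each resulting Subset wrapped in a DataLoader afterwards, rather than to a DataLoader.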

  • As far as I know, this was not introduced until 1.13. Correct me if I'm wrong. (2 upvotes)
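A quick way to tell which case applies is to compare the installed version against that threshold. The sketch below is pure Python and takes the version as a string; the function names are hypothetical, and the 1.13 cutoff is an assumption taken from the comment above rather than something verified here.

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from a string such as '1.10.0+cu113'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def supports_fractional_split(version: str) -> bool:
    # Assumption (from the comment above): fractional lengths arrived in 1.13.
    return major_minor(version) >= (1, 13)


print(supports_fractional_split("1.10.0+cu113"))  # False
print(supports_fractional_split("1.13.1"))        # True
```

In practice you would call it as supports_fractional_split(str(torch.__version__)) and fall back to integer lengths when it returns False.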