我尝试使用torch.utils.data.random_split如下:
import torch
from torch.utils.data import DataLoader, random_split
# Toy dataset: a plain Python list of ten integers.
list_dataset = [1,2,3,4,5,6,7,8,9,10]
# Wrap the list in a DataLoader that yields one element per batch, in order.
# NOTE(review): the object handed to random_split below is this DataLoader,
# not list_dataset itself -- presumably the raw dataset was intended; confirm.
dataset = DataLoader(list_dataset, batch_size=1, shuffle=False)
# Request an 80/10/10 split expressed as fractions, with a seeded generator
# so the permutation is reproducible. The returned subsets are discarded here.
random_split(dataset, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(123))
Run Code Online (Sandbox Code Playgroud)
但是,当我运行这段代码时,收到了如下错误:raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
我查看了文档,似乎我应该能够传递总和为 1 的小数,但显然它不起作用。
我也在谷歌上搜索了这个错误,找到的最接近的结果就是这个问题。
我究竟做错了什么?
您可能使用的是旧版本的 PyTorch(例如 PyTorch 1.10),该版本的 random_split 还不支持传入总和为 1 的小数比例。
要在旧版本中复制此功能,您只需复制新版本的源代码即可:
import math
import warnings

from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data.dataset import Subset
def random_split(dataset, lengths,
                 generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
    >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
    ...   ).manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Returns:
        list: one ``Subset`` of ``dataset`` per entry in ``lengths``, in the
        same order as ``lengths``.

    Raises:
        ValueError: if a fraction lies outside ``[0, 1]``, or if the
            (computed) lengths do not add up to ``len(dataset)``.
    """
    total_requested = sum(lengths)
    # Inputs summing to (approximately) 1 are interpreted as fractions of the
    # dataset; anything else is taken as absolute lengths.
    if math.isclose(total_requested, 1) and total_requested <= 1:
        dataset_size = len(dataset)
        subset_lengths = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            subset_lengths.append(int(math.floor(dataset_size * frac)))

        # Flooring may leave a few items unassigned: hand them out one at a
        # time, round-robin, starting from the first split.
        remainder = dataset_size - sum(subset_lengths)
        for i in range(remainder):
            subset_lengths[i % len(subset_lengths)] += 1

        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()

    # Carve the shuffled indices into consecutive, non-overlapping slices.
    # A running offset is used instead of a private torch helper.
    subsets = []
    start = 0
    for length in lengths:
        subsets.append(Subset(dataset, indices[start:start + length]))
        start += length
    return subsets
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
8078 次 |
| 最近记录: |