pyTorch 中数据集和 TensorDataset 之间的区别

מור*_*ניק 16 deep-learning pytorch

“torch.utils.data.TensorDataset”和“torch.utils.data.Dataset”之间有什么区别 - 文档对此并不清楚,我在谷歌上找不到任何答案。

OSa*_*inz 17

该类Dataset是一个抽象类,用于定义新类型的(海关)数据集。相反,它TensorDataset是一个随时可用的类,将您的数据表示为张量列表。

您可以通过以下方式定义自定义数据集:

class CustomDataset(torch.utils.data.Dataset):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Your code

    self.instances = your_data

  def __getitem__(self, idx):
    return self.instances[idx] # In case you stored your data on a list called instances

  def __len__(self):
    return len(self.instances)
Run Code Online (Sandbox Code Playgroud)

如果您只想创建一个包含输入特征和标签张量的数据集,则TensorDataset直接使用:

dataset = TensorDataset(input_features, labels)
Run Code Online (Sandbox Code Playgroud)

请注意,input_featureslabels必须匹配第一个维度的长度。