PyTorch数据集类的子类找不到数据集文件

san*_*oft 5 python constructor subclass python-3.x pytorch

我正在尝试创建 PyTorch MNIST 数据集类的子类,我将其称为 CustomMNISTDataset,如下所示:

import torchvision.datasets as datasets

class CustomMNISTDataset(datasets.MNIST):

    def __init__(self, root='/home/psando'):
        super().__init__(root=root,
                         download=False)
Run Code Online (Sandbox Code Playgroud)

但是当我执行时:

dataset = CustomMNISTDataset()
Run Code Online (Sandbox Code Playgroud)

它失败并出现错误:“RuntimeError:未找到数据集。您可以使用 download=True 来下载它”。

但是,当我在同一文件中运行以下命令时:

dataset = datasets.MNIST(root='/home/psando', download=False)
print(len(dataset))
Run Code Online (Sandbox Code Playgroud)

正如预期的那样,它成功并打印“60000”。

既然是CustomMNISTDataset子类,datasets.MNIST 为什么行为不同?我已经验证路径“/home/psando”包含 MNIST 目录以及原始子目录和已处理子目录(否则,显式调用构造函数datasets.MNIST()将会失败)。super().__init__()当前的行为意味着对inside的调用CustomMNISTDataset没有调用构造函数,datasets.MNIST这很奇怪!

其他细节:我正在使用 Python 3.6.8torch==1.6.0torchvision==0.7.0。任何帮助,将不胜感激!

小智 3

这需要一些源码挖掘,但你的问题是这个函数。数据集的路径取决于类的名称,因此当您子类化时MNIST根文件夹将更改为/home/psando/CustomMNISTDataset

所以如果你重命名/home/psando/MNIST它就/home/psando/CustomMNISTDataset可以了。