如何从 torchvision.datasets.ImageFolder 获取 n 个图像

use*_*821 6 pytorch

我目前正在使用 PyTorch 进行 CNN 实验,我希望模型完成的任务是对图像进行分类。

我知道使用torchvision.datasets.ImageFolder可以帮助根据每个子文件夹的名称作为标签从我的训练文件夹加载所有图像。

我计划只从ImageFolder随机中获取 n 个图像,但据我所知,没有机制可以ImageFolder随机加载n 个图像,其中是从 1 到所有可用图像之间的任何数字。n

我怎样才能做到这一点?谢谢你的帮助

jod*_*dag 9

您可以ImageFolder使用 PyTorch 的Subset类创建一个子集。如果您愿意,我们可以使用 numpy 或其他方式生成随机索引。

dataset = torchvision.datasets.ImageFolder(...)
dataset_subset = torch.utils.data.Subset(dataset, numpy.random.choice(len(dataset), n, replace=False))
Run Code Online (Sandbox Code Playgroud)