pytorch:获取给定 ImageFolder 数据集的类数

Tom*_*ale 2 python machine-learning deep-learning pytorch

如果我有一个像这样的数据集:

image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)
Run Code Online (Sandbox Code Playgroud)

如何以编程方式确定数据集中的类或唯一标签的数量?

Pet*_*ter 6

如果您的数据类型是张量,那么您可以使用:

import torch  
n_classes = len(torch.unique(Your_Target_Vector))
Run Code Online (Sandbox Code Playgroud)