Post by Ami*_*iri

Number of instances per class in a PyTorch dataset

I am trying to build a simple image classifier with PyTorch. This is how I load my data into a dataset and DataLoaders:

import torch
from torch.utils.data import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder

batch_size = 64
validation_split = 0.2
# PROJECT_PATH and CustomToTensor are defined elsewhere in the project.
data_dir = PROJECT_PATH + "/categorized_products"
transform = transforms.Compose([transforms.Grayscale(), CustomToTensor()])

dataset = ImageFolder(data_dir, transform=transform)

# The first 80% of the indices go to training, the remaining 20% to testing.
indices = list(range(len(dataset)))
split = int(len(indices) * (1 - validation_split))
train_indices = indices[:split]
test_indices = indices[split:]

train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)
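One thing worth noting about the split above: ImageFolder lists its samples class by class, so taking the first 80% of the indices without shuffling can leave the test subset dominated by the last classes. A minimal sketch of a shuffled split follows. The helper name `split_indices` and the fixed `seed` are hypothetical, introduced only for illustration.

```python
import random


def split_indices(n_samples, validation_split=0.2, seed=0):
    """Return (train_indices, test_indices) after shuffling 0..n_samples-1.

    Hypothetical helper: the name and the seed default are illustrative only.
    """
    indices = list(range(n_samples))
    # Shuffle with a private, seeded RNG so the split is reproducible.
    random.Random(seed).shuffle(indices)
    n_test = int(n_samples * validation_split)
    return indices[n_test:], indices[:n_test]


# With the real dataset, n_samples would be len(dataset).
train_indices, test_indices = split_indices(100, validation_split=0.2)
print(len(train_indices), len(test_indices))
```

The two lists can then be passed to SubsetRandomSampler exactly as before.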

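I want to print the number of images in each class, separately for the training data and the test data, like this:

In train data:

  • shoes: 20
  • shirts: 14

In test data:

  • shoes: 4
  • shirts: 3

I tried this: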

from collections import Counter
print(dict(Counter(sample_tup[1] for sample_tup in dataset.imgs)))

But I got this error:

AttributeError: 'MyDataset' object has no attribute 'img'
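One possible way to get per-split counts is to count labels only over the indices that belong to each split. The sketch below assumes the dataset exposes `targets` (one integer label per sample) and `classes` (the class names), as torchvision's ImageFolder does. A custom class such as the `MyDataset` named in the error would need to provide equivalents. The helper `count_per_class` and the stand-in data are hypothetical, used here so the example runs without any images.

```python
from collections import Counter


def count_per_class(targets, indices, classes):
    """Count how many of the samples selected by `indices` fall in each class."""
    counts = Counter(targets[i] for i in indices)
    # Report every class, including those with no samples in this subset.
    return {name: counts.get(label, 0) for label, name in enumerate(classes)}


# Stand-in data for illustration; with ImageFolder these would be
# dataset.targets and dataset.classes.
targets = [0, 0, 0, 1, 1, 1, 1, 0, 1, 0]
classes = ["shirts", "shoes"]

indices = list(range(len(targets)))
split = int(len(indices) * 0.8)
train_indices, test_indices = indices[:split], indices[split:]

train_counts = count_per_class(targets, train_indices, classes)
test_counts = count_per_class(targets, test_indices, classes)
print("In train data:", train_counts)
print("In test data:", test_counts)
```

Counting from the label list avoids loading and transforming any images, which iterating over the DataLoaders would do.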

python torch pytorch dataloader
