Leo*_*ewk 4 python machine-learning computer-vision pytorch
我正在 Kaggle 上进行仙人掌图像竞赛,我正在尝试将 PyTorch 数据加载器用于我的 CNN。但是,我遇到了无法为训练集设置标签的问题。训练集图像在文件夹中给出,标签在 csv 文件中。这是我的代码。
train = torchvision.datasets.ImageFolder(root='../input/train',
transform=transform)
train.targets = torch.from_numpy(df['has_cactus'].values)
train_loader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)
for i, data in enumerate(train_loader, 0):
print(data[1])
Run Code Online (Sandbox Code Playgroud)
此代码输出全为零的批处理张量,这显然是不正确的,因为绝大多数标签(如果您要查看数据帧)都是 1。我相信这是将标签分配给“train.targets”的问题。如果在分配其他标签之前打印“train.targets”,它会返回一个全为零的张量,这与我得到的错误结果一致。我该如何解决这个问题?
小智 7
我通常继承内置的 DataSet 类,如下所示:
from torch.utils.data import DataLoader
class DataSet:
def __init__(self, root):
"""Init function should not do any heavy lifting, but
must initialize how many items are available in this data set.
"""
self.ROOT = root
self.images = read_images(root + "/images")
self.labels = read_labels(root + "/labels")
def __len__(self):
"""return number of points in our dataset"""
return len(self.images)
def __getitem__(self, idx):
""" Here we have to return the item requested by `idx`
The PyTorch DataLoader class will use this method to make an iterable for
our training or validation loop.
"""
img = images[idx]
label = labels[idx]
return img, label
Run Code Online (Sandbox Code Playgroud)
现在,您可以创建此类的一个实例,
ds = Dataset('../input/train')
Run Code Online (Sandbox Code Playgroud)
现在,您可以实例化 DataLoader:
dl = DataLoader(ds, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=4, drop_last=True)
Run Code Online (Sandbox Code Playgroud)
这将创建您可以访问的批量数据:
for image, label in dl:
print(label)
Run Code Online (Sandbox Code Playgroud)