PyTorch用于什么变换?

car*_*a88 7 transformation image input pytorch

我是Pytorch的新手,并不是CNN的专家.我已经用他们提供Tutorial Pytorch的教程做了一个成功的分类,但我真的不明白我在加载数据时正在做什么.

他们为训练做了一些数据增加和规范化,但是当我尝试修改参数时,代码不起作用.

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
Run Code Online (Sandbox Code Playgroud)

我是否在扩展我的训练数据集?我没有看到数据扩充.

为什么我修改transforms.RandomResizedCrop(224)的值,数据加载停止工作?

我是否还需要转换测试数据集?

我对他们所做的数据转换有点困惑.

lay*_*yog 17

transforms.Compose只是俱乐部提供给它的所有变革.因此,它中的所有变换transforms.Compose都逐个应用于输入.

火车改造

  1. transforms.RandomResizedCrop(224):这(224, 224)将从您的输入图像中随机提取一个大小的补丁.因此,它可能会选择这条路径来自topleft,bottomright或其间的任何位置.因此,您正在进行此部分的数据扩充.此外,更改此值对于模型中完全连接的图层不会很好,因此不建议更改此值.
  2. transforms.RandomHorizontalFlip():一旦我们得到尺寸图像(224, 224),我们就可以选择翻转它.这是数据扩充的另一部分.
  3. transforms.ToTensor():这只是将输入图像转换为PyTorch张量.
  4. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):这只是输入数据缩放,并且必须为数据集预先计算这些值(mean和std).也不建议更改这些值.

验证转换

  1. transforms.Resize(256):首先将输入图像的大小调整为大小 (256, 256)
  2. transforms.CentreCrop(224):裁剪形状图像的中心部分 (224, 224)

休息和火车一样

PS:您可以在官方文档中阅读有关这些转换的更多信息