我是pytorch的新手,想了解一些东西。
我正在按以下方式加载MNIST:
transform_train = transforms.Compose(
[transforms.ToTensor(),
transforms.Resize(size, interpolation=2),
# transforms.Grayscale(num_output_channels=1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize((mean), (std))])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
Run Code Online (Sandbox Code Playgroud)
但是,当我探索数据集时,即trainloader.dataset.train_data[0],我得到的张量为[0,255],形状为(28,28)。
我想念什么?这是因为转换没有直接应用于数据加载器,而是仅在运行时?否则我该如何浏览我的数据?
调用的__getitem__方法时应用转换Dataset。例如,查看数据集类的__getitem__方法MNIST:https : //github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
Run Code Online (Sandbox Code Playgroud)
__getitem__当您MNIST为训练集索引实例时,将调用该方法,例如:
trainset[0]
Run Code Online (Sandbox Code Playgroud)
有关更多信息__getitem__:https : //docs.python.org/3.6/reference/datamodel.html#object。getitem
为什么原因Resize和RandomHorizontalFlip前应ToTensor是它们作用于PIL图像和所有数据集中在Pytorch一致性负载数据作为PIL Image第一秒。实际上,您可以在这里看到他们通过以下方式强制执行该行为:
img = Image.fromarray(img.numpy(), mode='L')
Run Code Online (Sandbox Code Playgroud)
一旦你有PIL Image相应的指数,在变换被施加
if self.transform is not None:
img = self.transform(img)
Run Code Online (Sandbox Code Playgroud)
ToTensor将PIL Imagea 转换为a torch.Tensor并Normalize减去平均值,然后除以您提供的标准差。
最终,一些转换将应用于
if self.target_transform is not None:
target = self.target_transform(target)
Run Code Online (Sandbox Code Playgroud)
最后,返回处理后的图像和处理后的标签。所有这些都在一个trainset[key]电话中发生。
import torch
from torchvision.transforms import *
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
transform_train = Compose([Resize(28, interpolation=2),
RandomHorizontalFlip(p=0.5),
ToTensor(),
Normalize([0.], [1.])])
trainset = MNIST(root='./data', train=True, download=True,
transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())
Run Code Online (Sandbox Code Playgroud)
表演
(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1537 次 |
| 最近记录: |