Pet*_*les 1 python artificial-intelligence python-3.x deep-learning pytorch
我正在使用 pytorch 和数据集时尚 MNIST,但我不知道如何评估该数据集的平均值和标准差。这是我的代码:
import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((mean), (std))])
batch_size = 32
train_loader = torch.utils.data.DataLoader(datasets.MNIST(
'../data', train=True, download=True, transform=transform)
, batch_size=batch_size, shuffle=True)
Run Code Online (Sandbox Code Playgroud)
请问你能帮帮我吗 ?
非常感谢 !
小智 5
用它来计算平均值和标准差
loader = data.DataLoader(dataset,
batch_size=10,
num_workers=0,
shuffle=False)
mean = 0.
std = 0.
for images, _ in loader:
batch_samples = images.size(0) # batch size (the last batch can have smaller size!)
images = images.view(batch_samples, images.size(1), -1)
mean += images.mean(2).sum(0)
std += images.std(2).sum(0)
mean /= len(loader.dataset)
std /= len(loader.dataset)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2791 次 |
| 最近记录: |