use*_*868 4 pytorch torchvision
Pytorch 使用以下值作为 cifar10 数据的平均值和标准差:transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
我需要理解计算背后的概念,因为这些数据是 3 通道图像,我不明白什么是相加的,什么是除什么的等等。另外,如果有人可以分享计算平均值和标准差的代码,将非常感激。
0.5 值只是三个通道(r、g、b)上cifar10平均值和标准值的近似值。cifar10 训练集的精确值为
0.49139968, 0.48215827 ,0.446531240.24703233 0.24348505 0.26158768您可以使用以下脚本计算这些:
import torch
import numpy
import torchvision.datasets as datasets
from torchvision import transforms
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
imgs = [item[0] for item in cifar_trainset] # item[0] and item[1] are image and its label
imgs = torch.stack(imgs, dim=0).numpy()
# calculate mean over each channel (r,g,b)
mean_r = imgs[:,0,:,:].mean()
mean_g = imgs[:,1,:,:].mean()
mean_b = imgs[:,2,:,:].mean()
print(mean_r,mean_g,mean_b)
# calculate std over each channel (r,g,b)
std_r = imgs[:,0,:,:].std()
std_g = imgs[:,1,:,:].std()
std_b = imgs[:,2,:,:].std()
print(std_r,std_g,std_b)
Run Code Online (Sandbox Code Playgroud)
此外,您可能会在此处和此处找到相同的平均值和标准值
替代方式
from torchvision import datasets
cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True )
data = cifar_trainset.data / 255 # data is numpy array
mean = data.mean(axis = (0,1,2))
std = data.std(axis = (0,1,2))
print(f"Mean : {mean} STD: {std}") #Mean : [0.491 0.482 0.446] STD: [0.247 0.243 0.261]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
8234 次 |
| 最近记录: |