Finding the mean and standard deviation across image channels in PyTorch

ch1*_*era 8 python mean standard-deviation deep-learning pytorch

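Suppose I have a batch of images as a tensor of shape (B x C x W x H), where B is the batch size, C is the number of channels, and W and H are the image width and height, respectively. I want to use transforms.Normalize() to normalize my images by the dataset's mean and standard deviation computed per channel, which means I need result tensors of shape 1 x C. Is there a straightforward way to do this?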

I tried torch.view(C, -1).mean(1) and torch.view(C, -1).std(1), but got the error:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Edit

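After looking into how view() works in PyTorch, I understand why my approach does not work; however, I still cannot figure out how to get the per-channel mean and standard deviation.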

小智 9

Note that it is variances that add, not standard deviations. See the detailed explanation here: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters
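
A small sketch of why the distinction matters, using made-up numbers that are not from the original post: for two equally sized groups that share the same mean, the pooled standard deviation equals the square root of the averaged variances, which differs from the average of the two standard deviations.

```python
import torch

# Two hypothetical groups of equal size with the same mean (0) but different spread.
a = torch.tensor([-1.0, 1.0, -1.0, 1.0])
b = torch.tensor([-3.0, 3.0, -3.0, 3.0])

# Population (biased) statistics keep the arithmetic exact for this illustration.
pooled_std = torch.cat([a, b]).std(unbiased=False)

# Averaging the standard deviations underestimates the pooled spread ...
mean_of_stds = (a.std(unbiased=False) + b.std(unbiased=False)) / 2

# ... while averaging the variances and taking the square root matches it.
sqrt_mean_var = torch.sqrt((a.var(unbiased=False) + b.var(unbiased=False)) / 2)

print(pooled_std, mean_of_stds, sqrt_mean_var)
```

The equality only holds here because both groups have the same mean and size; when group means differ, the spread of the means contributes to the pooled variance as well.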

Here is the modified code:

import torch

# trainloader is assumed to be an existing DataLoader yielding (images, targets)
nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Accumulate the per-image channel means and variances
    mean += batch.mean(2).sum(0)
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
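
Averaging per-image variances leaves out the spread of the per-image means, so the result can differ slightly from the standard deviation taken over every pixel in the dataset. If the exact dataset-wide value is wanted, one possible approach is to accumulate per-channel sums and sums of squares. The sketch below assumes random stand-in data and illustrative variable names (`channel_sum`, `channel_sq_sum`); it uses float64 because the E[x^2] - (E[x])^2 formula can lose precision in float32.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-in data: 64 random "images" with 3 channels.
data = torch.randn(64, 3, 28, 28) * 2.0 + 0.5
loader = DataLoader(TensorDataset(data, torch.zeros(64, 1)), batch_size=8)

n_pixels = 0
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
for batch, _ in loader:
    batch = batch.double()
    # Number of pixels per channel seen so far: B * W * H
    n_pixels += batch.size(0) * batch.size(2) * batch.size(3)
    # Reduce over batch, width and height, keeping the channel dimension
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))

mean = channel_sum / n_pixels
# Population variance per channel: E[x^2] - (E[x])^2
std = torch.sqrt(channel_sq_sum / n_pixels - mean ** 2)

print(mean)
print(std)
```

Because only two running tensors of size C are kept, this works for datasets that do not fit in memory.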


trs*_*chn 8

You just need to rearrange the batch tensor in the right way: from [B, C, W, H] to [B, C, W * H]:

batch = batch.view(batch.size(0), batch.size(1), -1)
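
This reshape keeps B and C as the leading dimensions, so each row of the last axis holds the pixels of one channel of one image. Flattening straight to [C, -1] does not group values by channel, because the channel axis is second, not first. As a hedged alternative sketch (random data and variable names are illustrative), the channel axis can be moved to the front before flattening, or the reduction can be done over several dimensions at once:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 3, 28, 28)  # hypothetical batch: [B, C, W, H]

# Move the channel axis to the front, then flatten everything else.
# permute() returns a non-contiguous tensor, so reshape() is used instead of view().
per_channel = x.permute(1, 0, 2, 3).reshape(x.size(1), -1)  # [C, B * W * H]

mean = per_channel.mean(dim=1)
std = per_channel.std(dim=1)

# The same statistics without any reshaping, reducing over B, W and H directly.
mean_direct = x.mean(dim=(0, 2, 3))
std_direct = x.std(dim=(0, 2, 3))

print(mean, std)
```

Calling view() on the permuted tensor is the kind of situation that produces the "view size is not compatible with input tensor's size and stride" error quoted in the question.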

Here is a complete usage example on random data:

Code:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

Output:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
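
Once the per-channel statistics are available, they can be passed to transforms.Normalize(mean, std), which subtracts the mean and divides by the standard deviation for each channel. The sketch below performs the same arithmetic by hand with broadcasting so that it depends only on torch; the data and variable names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(8, 3, 28, 28) * 2.0 + 0.5  # hypothetical batch: [B, C, W, H]

# Per-channel statistics, each of shape [C]
mean = batch.mean(dim=(0, 2, 3))
std = batch.std(dim=(0, 2, 3))

# Reshape the [C] statistics to [1, C, 1, 1] so they broadcast over B, W and H.
normalized = (batch - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

# After normalization each channel should have roughly zero mean and unit std.
print(normalized.mean(dim=(0, 2, 3)))
print(normalized.std(dim=(0, 2, 3)))
```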