在Python中计算图像数据集通道明智的平均值和标准偏差的最快方法

Bru*_*ein 10 python opencv computer-vision

我有一个巨大的图像数据集,不适合内存.我想计算meanstandard deviation从磁盘加载图像.

我目前正在尝试使用维基百科上的这个算法.

# for a new value newValue, compute the new count, new mean, the new M2.
# mean accumulates the mean of the entire dataset
# M2 aggregates the squared distance from the mean
# count aggregates the amount of samples seen so far
def update(existingAggregate, newValue):
    (count, mean, M2) = existingAggregate
    count = count + 1 
    delta = newValue - mean
    mean = mean + delta / count
    delta2 = newValue - mean
    M2 = M2 + delta * delta2

    return existingAggregate

# retrieve the mean and variance from an aggregate
def finalize(existingAggregate):
    (count, mean, M2) = existingAggregate
    (mean, variance) = (mean, M2/(count - 1)) 
    if count < 2:
        return float('nan')
    else:
        return (mean, variance)
Run Code Online (Sandbox Code Playgroud)

这是我目前的实现(仅针对红色通道计算):

count = 0
mean = 0
delta = 0
delta2 = 0
M2 = 0
for i, file in enumerate(tqdm(first)):
    image = cv2.imread(file)
    for i in range(224):
        for j in range(224):
            r, g, b = image[i, j, :]
            newValue = r
            count = count + 1
            delta = newValue - mean
            mean = mean + delta / count
            delta2 = newValue - mean
            M2 = M2 + delta * delta2

print('first mean', mean)
print('first std', np.sqrt(M2 / (count - 1)))
Run Code Online (Sandbox Code Playgroud)

这个实现在我尝试的数据集的子集上足够接近.

问题是它非常慢,因此不可行.

  • 有这样做的标准方法吗?

  • 我如何调整这个以获得更快的结果或计算所有数据集的RGB均值和标准偏差,而无需在内存中同时以合理的速度加载它们?

Yo *_*iao 7

由于这是一项繁重的数值任务(围绕矩阵或张量进行大量迭代),因此我总是建议使用擅长于此的库:numpy。

正确安装的 numpy 应该能够利用底层 BLAS(基本线性代数子例程)例程,这些例程针对从内存层次结构的角度操作浮点数组进行了优化。

imread 应该已经给你 numpy 数组。您可以通过以下方式获得红色通道图像的重塑一维数组

import numpy as np
val = np.reshape(image[:,:,0], -1)
Run Code Online (Sandbox Code Playgroud)

这样的平均值

np.mean(val)
Run Code Online (Sandbox Code Playgroud)

和标准偏差

np.std(val)
Run Code Online (Sandbox Code Playgroud)

这样就可以去掉两层python循环:

count = 0
mean = 0
delta = 0
delta2 = 0
M2 = 0
for i, file in enumerate(tqdm(first)):
    image = cv2.imread(file)
        val = np.reshape(image[:,:,0], -1)
        img_mean = np.mean(val)
        img_std = np.std(val)
        ...
Run Code Online (Sandbox Code Playgroud)

其余的增量更新应该很简单。

一旦你这样做了,瓶颈将成为图像加载速度,这受到磁盘读取操作性能的限制。在这方面,根据我之前的经验,我怀疑按照其他人的建议使用多线程会很有帮助。

  • @Bruno Klein 您上面引用的方法是一种数值稳定的方法,非常棒。但是,如果它不会溢出,您也可以选择更简单的方法:每个图像均值的数组总和的均值是总体均值。std 是一个棘手的问题。如果您不关心世界级的实现,可能会为每个图像存储平方和 (sos) 并逐步求和。完成后,计算 `sqrt(sum_sos/N - mean**2)`(你会想知道 `N` 是什么)。请注意,如果 `sum_sos` 太接近于 `mean**2`,您可能需要注意数值问题。 (2认同)

And*_*dov 5

您还可以使用 opencv 的方法meanstddev

\n\n
cv2.meanStdDev(src[, mean[, stddev[, mask]]]) \xe2\x86\x92 mean, stddev\n
Run Code Online (Sandbox Code Playgroud)\n