填补numpy数组中的空白

Jos*_*ose 12 python interpolation numpy matplotlib scipy

我只想用最简单的术语插入3D数据集.线性插值,最近邻,所有这些就足够了(这是从一些算法开始,所以不需要准确的估计).

在新的scipy版本中,像griddata这样的东西会很有用,但是目前我只有scipy 0.8.所以我有一个"立方体"(data[:,:,:],(NixNjxNk))数组,以及一个相同大小的标志(flags[:,:,:,],TrueFalse)数组.我想插入数据元素的数据,其中flag的对应元素为False,例如使用数据中最近的有效数据点,或者"close by"点的某种线性组合.

数据集中至少有两个维度可能存在较大的间隙.除了使用kdtrees或类似编码完整的最近邻算法之外,我无法真正找到通用的N维最近邻插值器.

Juh*_*uh_ 29

使用scipy.ndimage,你的问题可以通过2行最近邻插值来解决:

from scipy import ndimage as nd

indices = nd.distance_transform_edt(invalid_cell_mask, return_distances=False, return_indices=True)
data = data[tuple(ind)]
Run Code Online (Sandbox Code Playgroud)

现在,以函数的形式:

import numpy as np
from scipy import ndimage as nd

def fill(data, invalid=None):
    """
    Replace the value of invalid 'data' cells (indicated by 'invalid') 
    by the value of the nearest valid data cell

    Input:
        data:    numpy array of any dimension
        invalid: a binary array of same shape as 'data'. 
                 data value are replaced where invalid is True
                 If None (default), use: invalid  = np.isnan(data)

    Output: 
        Return a filled array. 
    """    
    if invalid is None: invalid = np.isnan(data)

    ind = nd.distance_transform_edt(invalid, 
                                    return_distances=False, 
                                    return_indices=True)
    return data[tuple(ind)]
Run Code Online (Sandbox Code Playgroud)

使用范围:

def test_fill(s,d):
     # s is size of one dimension, d is the number of dimension
    data = np.arange(s**d).reshape((s,)*d)
    seed = np.zeros(data.shape,dtype=bool)
    seed.flat[np.random.randint(0,seed.size,int(data.size/20**d))] = True

    return fill(data,-seed), seed

import matplotlib.pyplot as plt
data,seed  = test_fill(500,2)
data[nd.binary_dilation(seed,iterations=2)] = 0   # draw (dilated) seeds in black
plt.imshow(np.mod(data,42))                       # show cluster
Run Code Online (Sandbox Code Playgroud)

结果: 在此输入图像描述

  • 哇,我不知道`distance_transform_edt`.这是一个非常有用的功能. (2认同)

Pau*_*aul 14

您可以设置晶体生长式算法,沿每个轴交替移动视图,仅替换标记为False但具有True邻居的数据.这给出了一个"最接近邻居"的结果(但不是欧几里德或曼哈顿距离 - 我认为如果你计算像素,它可能是最近邻,计算所有连接像素与公共角点)这对于NumPy应该是相当有效的因为它只迭代轴和收敛迭代,而不是数据的小片.

原油,快速而稳定.我认为这就是你所追求的:

import numpy as np
# -- setup --
shape = (10,10,10)
dim = len(shape)
data = np.random.random(shape)
flag = np.zeros(shape, dtype=bool)
t_ct = int(data.size/5)
flag.flat[np.random.randint(0, flag.size, t_ct)] = True
# True flags the data
# -- end setup --

slcs = [slice(None)]*dim

while np.any(~flag): # as long as there are any False's in flag
    for i in range(dim): # do each axis
        # make slices to shift view one element along the axis
        slcs1 = slcs[:]
        slcs2 = slcs[:]
        slcs1[i] = slice(0, -1)
        slcs2[i] = slice(1, None)

        # replace from the right
        repmask = np.logical_and(~flag[slcs1], flag[slcs2])
        data[slcs1][repmask] = data[slcs2][repmask]
        flag[slcs1][repmask] = True

        # replace from the left
        repmask = np.logical_and(~flag[slcs2], flag[slcs1])
        data[slcs2][repmask] = data[slcs1][repmask]
        flag[slcs2][repmask] = True
Run Code Online (Sandbox Code Playgroud)

为了更好地衡量,这里是由最初标记的数据播种的区域的可视化(2D)True.

在此输入图像描述