在numpy数组中获取项目的邻居

con*_*dor 3 python arrays numpy

我有一个看起来像这样的数组:

[['A0' 'B0' 'C0']
 ['A1' 'B1' 'C1']
 ['A2' 'B2' 'C2']]
Run Code Online (Sandbox Code Playgroud)

我想获得B1的邻居B0 , C1 , B2 , A1,以及其索引。

这是我想出的:

import numpy as np


arr = np.array([
    ['A0','B0','C0'],
    ['A1','B1','C1'],
    ['A2','B2','C2'],
])


def get_neighbor_indices(x,y):
    neighbors = []
    try:
        top = arr[y - 1, x]
        neighbors.append((top, (y - 1, x)))
    except IndexError:
        pass
    try:
        bottom = arr[y + 1, x]
        neighbors.append((bottom, (y + 1, x)))
    except IndexError:
        pass
    try:
        left = arr[y, x - 1]
        neighbors.append((left, (y, x - 1)))
    except IndexError:
        pass
    try:
        right = arr[y, x + 1]
        neighbors.append((right, (y, x + 1)))
    except IndexError:
        pass
    return neighbors
Run Code Online (Sandbox Code Playgroud)

这将返回一个元组列表(value, (y, x))

有没有更好的方法可以做到这一点而无需依赖try / except?

Mad*_*ist 5

您可以直接在numpy中执行此操作,而无需任何例外,因为您知道数组的大小。的近邻的指标x, y由下式给出

inds = np.array([[x, y]]) + np.array([[1, 0], [-1, 0], [0, 1], [0, -1]])
Run Code Online (Sandbox Code Playgroud)

您可以轻松制作一个掩码,以指示哪些索引有效:

valid = (inds[:, 0] >= 0) & (inds[:, 0] < arr.shape[0]) & \
        (inds[:, 1] >= 0) & (inds[:, 1] < arr.shape[1])
Run Code Online (Sandbox Code Playgroud)

现在提取所需的值:

inds = inds[valid, :]
vals = arr[inds[:, 0], inds[:, 1]]
Run Code Online (Sandbox Code Playgroud)

最简单的返回值是inds, vals,但是如果您坚持保留原始格式,则可以将其转换为

[v, tuple(i) for v, i in zip(vals, inds)]
Run Code Online (Sandbox Code Playgroud)

附录

您可以轻松地对此进行修改以在任意尺寸上工作:

def neighbors(arr, *pos):
    pos = np.array(pos).reshape(1, -1)
    offset = np.zeros((2 * pos.size, pos.size), dtype=np.int)
    offset[np.arange(0, offset.shape[0], 2), np.arange(offset.shape[1])] = 1
    offset[np.arange(1, offset.shape[0], 2), np.arange(offset.shape[1])] = -1
    inds = pos + offset
    valid = np.all(inds >= 0, axis=1) & np.all(inds < arr.shape, axis=1)
    inds = inds[valid, :]
    vals = arr[tuple(inds.T)]
    return vals, inds
Run Code Online (Sandbox Code Playgroud)

给定一个N维数组arr和N个元素pos,可以通过将每个维顺序设置为1或来创建偏移量-1valid通过一起广播indsarr.shape,以及np.all跨每个N大小的行调用而不是为每个维度手动进行操作,可以大大简化掩码的计算。最后,通过将每一列分配给一个单独的维度,转换tuple(inds.T)inds变成一个实际的花式索引。转置是必需的,因为数组在行上进行迭代(dim 0)。