用第二个数组的两个值替换数组的数据

use*_*754 8 python arrays indexing numpy

我有两个 numpy 数组“元素”和“节点”。我的目标是收集这些数组的一些数据。我需要用“节点”数组中包含的两个坐标替换最后两列的“元素”数据。这两个数组非常大,我必须自动化它。

这篇文章引用了一个旧帖子:用第二个数组的 2 个值替换数组的数据

不同之处在于数组非常大(元素:(3342558,5)和节点:(581589,4))并且以前的出路不起作用。

一个例子 :

    import numpy as np
    
    Elements = np.array([[1.,11.,14.],[2.,12.,13.]])
    
    nodes = np.array([[11.,0.,0.],[12.,1.,1.],[13.,2.,2.],[14.,3.,3.]])
    
    results = np.array([[1., 0., 0., 3., 3.],
    [2., 1., 1., 2., 2.]])
Run Code Online (Sandbox Code Playgroud)

之前hpaulj提出的出路

    e = Elements[:,1:].ravel().astype(int)
    n=nodes[:,0].astype(int)
    
    I, J = np.where(e==n[:,None])
    
    results = np.zeros((e.shape[0],2),nodes.dtype)
    results[J] = nodes[I,:1]
    results = results.reshape(2,4)
Run Code Online (Sandbox Code Playgroud)

但是对于庞大的数组,此脚本不起作用:
DepreciationWarning: elementwise comparison failed; this will raise an error in the future...

Div*_*kar 2

Elements游戏的大部分内容是从in 中找出相应的匹配索引nodes

方法#1

由于您似乎愿意转换为整数,因此我们假设我们可以将它们视为整数。这样,我们就可以使用基于array-assignment+mapping的方法,如下所示:

ar = Elements.astype(int)
a = ar[:,1:].ravel()
nd = nodes[:,0].astype(int)

n = a.max()+1
# for generalized case of neagtive ints in a or nodes having non-matching values:
# n = max(a.max()-min(0,a.min()), nd.max()-min(0,nd.min()))+1

lookup = np.empty(n, dtype=int)
lookup[nd] = np.arange(len(nd))
indices = lookup[a]

nc = (Elements.shape[1]-1)*(nodes.shape[1]-1) # 4 for given setup
out = np.concatenate((ar[:,0,None], nodes[indices,1:].reshape(-1,nc)),axis=1)
Run Code Online (Sandbox Code Playgroud)

方法#2

我们也可以用来np.searchsorted获取这些indices

对于根据第一个列和匹配情况排序行的节点,我们可以简单地使用:

indices = np.searchsorted(nd, a)
Run Code Online (Sandbox Code Playgroud)

对于不一定排序的情况和匹配的情况:

sidx = nd.argsort()
idx = np.searchsorted(nd, a, sorter=sidx)
indices = sidx[idx]
Run Code Online (Sandbox Code Playgroud)

对于不匹配的情况,请使用无效的 bool 数组:

invalid = idx==len(nd)
idx[invalid] = 0
indices = sidx[idx]
Run Code Online (Sandbox Code Playgroud)

方法#3

另一个带有concatenation+ sorting-

b = np.concatenate((nd,a))
sidx = b.argsort(kind='stable')

n = len(nd)
v = sidx<n
counts = np.diff(np.flatnonzero(np.r_[v,True]))
r = np.repeat(sidx[v], counts)

indices = np.empty(len(a), dtype=int)
indices[sidx[~v]-n] = r[sidx>=n]
Run Code Online (Sandbox Code Playgroud)

要检测不匹配的,请使用:

nd[indices] != a
Run Code Online (Sandbox Code Playgroud)

将这里的想法移植到numba

from numba import njit

def numba1(Elements, nodes):
    a = Elements[:,1:].ravel()
    nd = nodes[:,0]
    b = np.concatenate((nd,a))
    sidx = b.argsort(kind='stable')
    
    n = len(nodes)        
    ncols = Elements.shape[1]-1
    size = nodes.shape[1]-1        
    dt = np.result_type(Elements.dtype, nodes.dtype)
    nc = ncols*size
    
    out = np.empty((len(Elements),1+nc), dtype=dt)
    out[:,0] = Elements[:,0]
    return numba1_func(out, sidx, nodes, n, ncols, size)

@njit
def numba1_func(out, sidx, nodes, n, ncols, size):
    N = len(sidx)    
    for i in range(N):
        if sidx[i]<n:
            cur_id = sidx[i]
            continue
        else:
            idx = sidx[i]-n        
            row = idx//ncols
            col = idx-row*ncols        
            cc = col*size+1
            for ii in range(size):
                out[row, cc+ii] = nodes[cur_id,ii+1]
    return out
Run Code Online (Sandbox Code Playgroud)