Numpyufunc
有一种reduceat
方法可以在数组中的连续分区上运行它们。所以不要写:
import numpy as np
a = np.array([4, 0, 6, 8, 0, 9, 8, 5, 4, 9])
split_at = [4, 5]
maxima = [max(subarray for subarray in np.split(a, split_at)]
Run Code Online (Sandbox Code Playgroud)
我可以写:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
Run Code Online (Sandbox Code Playgroud)
a[0:4]
两者都会返回 slices 、a[4:5]
、a[5:10]
、being中的最大值[8, 0, 9]
。
我想要一个类似的函数来执行argmax
,请注意,我只希望每个分区中有一个最大索引:[3, 4, 5]
使用上面的a
和split_at
(尽管索引 5 和 9 都获得了最后一组中的最大值),如由
np.hstack([0, split_at]) + [np.argmax(subarray) for subarray in np.split(a, split_at)]
Run Code Online (Sandbox Code Playgroud)
我将在下面发布一种可能的解决方案,但希望看到一种无需在组上创建索引即可进行矢量化的解决方案。
该解决方案涉及在组上构建索引([0, 0, 0, 0, 1, 2, 2, 2, 2, 2]
在上面的示例中)。
group_lengths = np.diff(np.hstack([0, split_at, len(a)]))
n_groups = len(group_lengths)
index = np.repeat(np.arange(n_groups), group_lengths)
Run Code Online (Sandbox Code Playgroud)
然后我们可以使用:
maxima = np.maximum.reduceat(a, np.hstack([0, split_at]))
all_argmax = np.flatnonzero(np.repeat(maxima, group_lengths) == a)
result = np.empty(len(group_lengths), dtype='i')
result[index[all_argmax[::-1]]] = all_argmax[::-1]
Run Code Online (Sandbox Code Playgroud)
进入。[3, 4, 5]
result
s[::-1]
确保我们得到每组中的第一个而不是最后一个 argmax。
这依赖于这样一个事实:花式赋值中的最后一个索引决定了分配的值,@seberg说不应该依赖这一事实(并且可以使用 实现更安全的替代方案result = all_argmax[np.unique(index[all_argmax], return_index=True)[1]]
,这涉及对元素进行排序len(maxima) ~ n_groups
)。
归档时间: |
|
查看次数: |
1611 次 |
最近记录: |