我有一个类型为f32(from ndarray::ArrayView2)的二维数组,我想找到每一行中最大值的索引,并将索引值放入另一个数组中。
Python 中的等价物类似于:
import numpy as np
for i in range (0, max_val, batch_size):
sims = xp.dot(batch, vectors.T)
# sims is the dot product of batch and vectors.T
# the shape is, for example, (1024, 10000)
best_rows[i: i+batch_size] = sims.argmax(axis = 1)
Run Code Online (Sandbox Code Playgroud)
在 Python 中,该函数.argmax非常快,但我在 Rust 中没有看到任何类似的函数。这样做的最快方法是什么?
小智 8
考虑一般Ord类型的简单情况:根据您是否知道这些值,答案会略有不同Copy,但代码如下:
fn position_max_copy<T: Ord + Copy>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by_key(|(_, &value)| value).map(|(idx, _)| idx)
}
fn position_max<T: Ord>(slice: &[T]) -> Option<usize> {
slice.iter().enumerate().max_by(|(_, value0), (_, value1)| value0.cmp(value1)).map(|(idx, _)| idx)
}
Run Code Online (Sandbox Code Playgroud)
基本思想是,我们将数组中的每个项目(实际上是一个切片 - 无论它是 Vec 还是数组还是更奇特的东西)与其索引配对,使用函数来查找std::iter::Iterator最大值value 仅根据值(而不是索引),然后仅返回索引。如果切片为空None则返回。根据文档,将返回最右边的索引;如果您需要最左边的,请执行rev() after enumerate()。
rev()、enumerate()、max_by_key()、 和max_by()记录在此处;slice::iter()已在此处记录(但作为 rust 开发人员,需要在没有文档的情况下将其列入您要回忆的事情的候选清单中);map记录在这里Option::map()(同上)。哦,但大多数时候您可以使用不需要它的版本(例如,如果您正在比较整数)。cmpOrd::cmpCopy
现在有一个问题:f32不是Ord因为 IEEE 浮点数的工作方式。大多数语言都忽略了这一点并且有微妙错误的算法。最流行的提供总顺序的板条箱Ord(通过声明所有 NaN 相等,并且大于所有数字)似乎是ordered-float。假设它正确实现,它应该是非常非常轻量级的。它确实会引入num_traits,但这是最流行的数字库的一部分,因此很可能已经被其他依赖项引入了。
ordered_float::OrderedFloat在这种情况下,您可以通过将(元组类型的“构造函数”)映射到切片 iter ( ) 上来使用它slice.iter().map(ordered_float::OrderedFloat)。由于您只想要最大元素的位置,因此无需随后提取 f32。
小智 4
@David A 的方法很酷,但正如前面提到的,有一个问题:f32& f64do not Implement Ord::cmp。(这确实是一个你知道的痛苦。)
有多种方法可以解决这个问题:您可以cmp自己实现,或者可以使用ordered-float等等。
就我而言,这是一个更大项目的一部分,我们在使用外部包时非常小心。此外,我很确定我们没有任何NaN价值观。因此,我更喜欢使用fold,如果你仔细查看max_by_key源代码,你会发现他们也一直在使用它。
for (i, row) in matrix.axis_iter(Axis(1)).enumerate() {
let (max_idx, max_val) =
row.iter()
.enumerate()
.fold((0, row[0]), |(idx_max, val_max), (idx, val)| {
if &val_max > val {
(idx_max, val_max)
} else {
(idx, *val)
}
});
}
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2211 次 |
| 最近记录: |