存储多维数组/张量的最佳方法

teu*_*cer 4 scala

我正在尝试创建一个张量(可以被设想为多维数组)包scala.到目前为止,我将数据存储在1D中Vector并进行索引算法.

但切片和子阵列并不容易获得.人们需要做很多算术才能将多维索引转换为1D索引.

有没有存储多维数组的最佳方法?如果没有,即一维数组是最好的解决方案,那么如何最佳地对数组进行切片(一些具体的代码对我有帮助)?

Rex*_*err 5

回答这个问题的关键是:什么时候指针间接比算术更快?答案几乎从来没有.对于2D,有序遍历的速度大致相同,事情变得更糟:

2D random access
  Array of Arrays - 600 M / second
  Multiplication - 1.1 G / second

3D in-order
  Array of Array of Arrays - 2.4G / second
  Multiplication - 2.8 G / second

(etc.)
Run Code Online (Sandbox Code Playgroud)

所以你最好只做数学.

现在的问题是如何进行切片.最初,如果您的尺寸为n1,n2,n3,...以及i1,i2,i3,...的索引,则计算到数组的偏移量

i = i1 + n1*(i2 + n2*(i3 + ... ))
Run Code Online (Sandbox Code Playgroud)

通常i1选择最后(最里面)维度(但通常​​它应该是最内层循环中最常见的维度).也就是说,如果它是一个(...)数组的数组,你可以将其索引为a(...)(i3)(i2)(i1).

现在假设您要切片.首先,您可以为每个索引提供偏移量o1,o2,o3:

i = (i1 + o1) + n1*((i2 + o2) + n2*((i3 + o3) + ...))
Run Code Online (Sandbox Code Playgroud)

然后你会有一个较短的范围(让我们称之为m1,m2,m3,...).

最后,如果你完全消除一个维度 - 比方说,那就是说m2 == 1,这意味着i2 == 0,你只需简化公式:

i = (i1 + o1 + n1*o2) + (n1+n2)*((i3 + o3) + ... ))
Run Code Online (Sandbox Code Playgroud)

我将把它作为练习留给读者来弄清楚如何一般地做到这一点,但请注意我们可以存储新的常量o1 + n1*o21,n1+n2因此我们不需要在切片上继续进行数学运算.

最后,如果您允许任意维度,您只需将该数学放入while循环中.不可否认,这确实会让它减慢一点,但你仍然至少可以像使用指针取消引用一样好(几乎在所有情况下).