M.T*_*M.T 9 python numpy numpy-broadcasting
假设我想有大小的numpy的阵列(n,m),其中n是非常大的,但有很多重复,即.0:n1是相同的,n1:n2是相同的等(有n2%n1!=0,但不是规则的间隔).有没有办法只为每个重复项存储一组值,同时拥有整个数组的视图?
例:
unique_values = np.array([[1,1,1], [2,2,2] ,[3,3,3]]) #these are the values i want to store in memory
index_mapping = np.array([0,0,1,1,1,2,2]) # a mapping between index of array above, with array below
unique_values_view = np.array([[1,1,1],[1,1,1],[2,2,2],[2,2,2],[2,2,2], [3,3,3],[3,3,3]]) #this is how I want the view to look like for broadcasting reasons
Run Code Online (Sandbox Code Playgroud)
我计划将数组(视图)乘以其他一些大小的数组(1,m),并取这个产品的点积:
other_array1 = np.arange(unique_values.shape[1]).reshape(1,-1) # (1,m)
other_array2 = 2*np.ones((unique_values.shape[1],1)) # (m,1)
output = np.dot(unique_values_view * other_array1, other_array2).squeeze()
Run Code Online (Sandbox Code Playgroud)
输出是长度为1D的数组n.
根据您的示例,您可以简单地将通过计算的索引映射计算到最后:
output2 = np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]
assert (output == output2).all()
Run Code Online (Sandbox Code Playgroud)
您的表达式有两个重要的优化:
other_array1先乘以other_array2,然后
乘以unique_values让我们应用这些优化:
>>> output_pp = (unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]
# check for correctness
>>> (output == output_pp).all()
True
# and compare it to @Yakym Pirozhenko's approach
>>> from timeit import timeit
>>> print("yp:", timeit("np.dot(unique_values * other_array1, other_array2).squeeze()[index_mapping]", globals=globals()))
yp: 3.9105667411349714
>>> print("pp:", timeit("(unique_values @ (other_array1.ravel() * other_array2.ravel()))[index_mapping]", globals=globals()))
pp: 2.2684884609188884
Run Code Online (Sandbox Code Playgroud)
如果我们观察两件事,这些优化就很容易发现:
(1) 如果A是mxn- 矩阵并且b是n- 向量 那么
A * b == A @ diag(b)
A.T * b[:, None] == diag(b) @ A.T
Run Code Online (Sandbox Code Playgroud)
(2) ifA是一个mxn-matrix 并且I是一个-thenk整数向量
range(m)
A[I] == onehot(I) @ A
Run Code Online (Sandbox Code Playgroud)
onehot可以定义为
def onehot(I, m, dtype=int):
out = np.zeros((I.size, m), dtype=dtype)
out[np.arange(I.size), I] = 1
return out
Run Code Online (Sandbox Code Playgroud)
利用这些事实并缩写uv, im,oa1我们oa2可以写
uv[im] * oa1 @ oa2 == onehot(im) @ uv @ diag(oa1) @ oa2
Run Code Online (Sandbox Code Playgroud)
上述优化现在只是为这些矩阵乘法选择最佳顺序的问题,即
onehot(im) @ (uv @ (diag(oa1) @ oa2))
Run Code Online (Sandbox Code Playgroud)
向后使用(1)和(2),我们从本文的开头获得了优化的表达式。
| 归档时间: |
|
| 查看次数: |
248 次 |
| 最近记录: |