看来我迷失在可能愚蠢的东西中.我有一个n维numpy数组,我想将它与一个维度(可以改变!)的向量(1d数组)相乘.举个例子,假设我想将第二个数组乘以第一个数组的0轴的1d数组,我可以这样做:
a=np.arange(20).reshape((5,4))
b=np.ones(5)
c=a*b[:,np.newaxis]
Run Code Online (Sandbox Code Playgroud)
很简单,但我想将这个想法扩展到n维(对于a,而b总是1d)和任何轴.换句话说,我想知道如何在正确的位置生成np.newaxis的切片.假设a是3d并且我想沿轴= 1乘以,我想生成正确给出的切片:
c=a*b[np.newaxis,:,np.newaxis]
Run Code Online (Sandbox Code Playgroud)
即给出a(例如3)的维数,以及我想要乘以的轴(比如轴= 1),我如何生成并传递切片:
np.newaxis,:,np.newaxis
Run Code Online (Sandbox Code Playgroud)
谢谢.
Div*_*kar 10
解决方案代码
import numpy as np
# Given axis along which elementwise multiplication with broadcasting
# is to be performed
given_axis = 1
# Create an array which would be used to reshape 1D array, b to have
# singleton dimensions except for the given axis where we would put -1
# signifying to use the entire length of elements along that axis
dim_array = np.ones((1,a.ndim),int).ravel()
dim_array[given_axis] = -1
# Reshape b with dim_array and perform elementwise multiplication with
# broadcasting along the singleton dimensions for the final output
b_reshaped = b.reshape(dim_array)
mult_out = a*b_reshaped
Run Code Online (Sandbox Code Playgroud)
示例运行步骤的演示 -
In [149]: import numpy as np
In [150]: a = np.random.randint(0,9,(4,2,3))
In [151]: b = np.random.randint(0,9,(2,1)).ravel()
In [152]: whos
Variable Type Data/Info
-------------------------------
a ndarray 4x2x3: 24 elems, type `int32`, 96 bytes
b ndarray 2: 2 elems, type `int32`, 8 bytes
In [153]: given_axis = 1
Run Code Online (Sandbox Code Playgroud)
现在,我们想要执行元素乘法given axis = 1.让我们创建dim_array:
In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
...: dim_array[given_axis] = -1
...:
In [155]: dim_array
Out[155]: array([ 1, -1, 1])
Run Code Online (Sandbox Code Playgroud)
最后,重塑b并执行元素乘法:
In [156]: b_reshaped = b.reshape(dim_array)
...: mult_out = a*b_reshaped
...:
Run Code Online (Sandbox Code Playgroud)
whos再次查看信息并特别注意b_reshaped&mult_out:
In [157]: whos
Variable Type Data/Info
---------------------------------
a ndarray 4x2x3: 24 elems, type `int32`, 96 bytes
b ndarray 2: 2 elems, type `int32`, 8 bytes
b_reshaped ndarray 1x2x1: 2 elems, type `int32`, 8 bytes
dim_array ndarray 3: 3 elems, type `int32`, 12 bytes
given_axis int 1
mult_out ndarray 4x2x3: 24 elems, type `int32`, 96 bytes
Run Code Online (Sandbox Code Playgroud)
避免复制数据,浪费资源!
利用转换和视图,而不是实际将数据复制 N 次到具有适当形状的新数组中(如现有答案所做的那样),可以提高内存效率。这是这样一个方法(基于@ShuxuanXU的代码):
def mult_along_axis(A, B, axis):
# ensure we're working with Numpy arrays
A = np.array(A)
B = np.array(B)
# shape check
if axis >= A.ndim:
raise AxisError(axis, A.ndim)
if A.shape[axis] != B.size:
raise ValueError(
"Length of 'A' along the given axis must be the same as B.size"
)
# np.broadcast_to puts the new axis as the last axis, so
# we swap the given axis with the last one, to determine the
# corresponding array shape. np.swapaxes only returns a view
# of the supplied array, so no data is copied unnecessarily.
shape = np.swapaxes(A, A.ndim-1, axis).shape
# Broadcast to an array with the shape as above. Again,
# no data is copied, we only get a new look at the existing data.
B_brc = np.broadcast_to(B, shape)
# Swap back the axes. As before, this only changes our "point of view".
B_brc = np.swapaxes(B_brc, A.ndim-1, axis)
return A * B_brc
Run Code Online (Sandbox Code Playgroud)
您可以构建一个切片对象,并在其中选择所需的维度:
import numpy as np
a = np.arange(18).reshape((3,2,3))
b = np.array([1,3])
ss = [None] * a.ndim
ss[1] = slice(None) # set the dimension along which to broadcast
print ss # [None, slice(None, None, None), None]
c = a*b[tuple(ss)] # convert to tuple to avoid FutureWarning from newer versions of Python
Run Code Online (Sandbox Code Playgroud)