Mon*_*all 7 numpy python-3.x numpy-ufunc
我正在研究cs231n,我很难理解这个索引是如何工作的.鉴于
x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[ 1.19034710e-01 -4.65005990e-01 8.93743168e-01 -9.78047129e-01
-8.88672957e-01 -4.66605091e-01]
[ -1.38617461e-03 -2.64569728e-01 -3.83712733e-01 -2.61360826e-01
8.07072009e-01 -5.47607277e-01]
[ -3.97087458e-01 -4.25187949e-02 2.57931759e-01 7.49565950e-01
1.37707667e+00 1.77392240e+00]]
[[ -1.20692745e+00 -8.28111550e-01 6.53041092e-01 -2.31247762e+00
-1.72370321e+00 2.44308033e+00]
[ -1.45191870e+00 -3.49328154e-01 6.15445782e-01 -2.84190582e-01
4.85997687e-02 4.81590106e-01]
[ -1.14828583e+00 -9.69055406e-01 -1.00773809e+00 3.63553835e-01
-1.28078363e+00 -2.54448436e+00]]]
Run Code Online (Sandbox Code Playgroud)
他们做的操作是
np.add.at(dW, x, dout)
x是二维数组.索引如何在这里工作?我浏览了np.ufunc.at
文档,但他们有一个简单的例子,有1d数组和常量:
np.add.at(a, [0, 1, 2, 2], 1)
Run Code Online (Sandbox Code Playgroud)
In [226]: x = [[0,4,1], [3,2,4]]
...: dW = np.zeros((5,6),int)
In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]:
array([[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0]])
Run Code Online (Sandbox Code Playgroud)
这样x
就没有任何重复的条目,因此add.at
与使用+=
索引相同.同样,我们可以使用以下方法读取更改的值:
In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])
Run Code Online (Sandbox Code Playgroud)
指数的工作方式相同,包括广播:
In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]:
array([[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]])
Run Code Online (Sandbox Code Playgroud)
broadcastable
对于索引,值必须是:
In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])
In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])
In [118]: dW
Out[118]:
array([[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 3, 0, 9, 0],
[ 0, 0, 4, 0, 11, 0],
[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0]])
Run Code Online (Sandbox Code Playgroud)
在这种情况下,索引定义(2,3)形状,因此(2,3),(3,),(2,1)和标量值起作用.(6,)没有.
在这种情况下,add.at
将(2,3)数组映射到(2,2)子阵列上dW
.
最近我也很难理解这行代码。希望我得到的可以帮助你,如果我错了,请纠正我。
这行代码中的三个数组如下:
x , whose shape is (N,T)
dW, ---(V,D)
dout ---(N,T,D)
Run Code Online (Sandbox Code Playgroud)
然后我们来看看我们想弄清楚发生了什么的行代码
np.add.at(dW, x, dout)
Run Code Online (Sandbox Code Playgroud)
如果你不想知道思考过程。上面的代码等价于:
for row in range(N):
for col in range(T):
dW[ x[row,col] , :] += dout[row,col, :]
Run Code Online (Sandbox Code Playgroud)
这是思考过程:
参考这个文档
https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ufunc.at.html
我们知道 x 是索引数组。所以关键是要理解dW[x]。这是使用另一个数组(x)索引数组(dW)的概念。如果你不熟悉这个概念,可以查看这个链接
https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html
一般而言,使用索引数组时返回的是一个与索引数组形状相同的数组,但被索引的数组的类型和值。
dW[x] 将给我们一个数组,其形状为 (N,T,D),(N,T) 部分来自 x,(D) 来自 dW (V,D)。注意这里,x 的每个元素都在 [0, v) 的范围内。
让我们以一些数字作为具体例子
x: np.array([[0,0],[0,0]]) ---- (2,2) N=2, T=2
dW: np.array([[0,0],[2,2]]) ---- (2,2) V=2, D=2
dout: np.arange(1,9).reshape(2,2,2) ----(2,2,2) N=2, T=2, D=2
dW[x] should be [ [[0 0] #this comes from the dW's firt row
[0 0]]
[[0 0]
[0 0]] ]
Run Code Online (Sandbox Code Playgroud)
dW[x] add dout 表示添加 elemnet 项(这里是一些技巧,稍后会解释)
np.add.at(dW, x, dout) gives
[ [16 20]
[ 2 2] ]
Run Code Online (Sandbox Code Playgroud)
为什么?程序是:
它将 [1,2] 添加到 dW 的第一行,即 [0,0]。
为什么是第一排?因为x[0,0] = 0,表示dW的第一行,dW[0] = dW[0,:] = 第一行。
然后将 [3,4] 添加到 dW[0,0] 的第一行。[3,4]=dout[0,1,:]。[0,0]再次来自dW,x[0,1] = 0,仍然是dW[0]的第一行。
然后将 [5,6] 添加到 dW 的第一行。
然后将 [7,8] 添加到 dW 的第一行。
所以结果是 [1+3+5+7, 2+4+6+8] = [16,20]。因为我们没有接触dW的第二排。dW 的第二行保持不变。
诀窍是我们只对原行计数一次,可以认为没有缓冲区,每一步都在原处播放。