roe*_*and 5 python numpy automatic-differentiation
我正在尝试使用行为类似于NumPy数组的类来实现自动微分.它不是子类numpy.ndarray,但包含两个数组属性.一个用于值,一个用于雅可比矩阵.每个操作都被重载以对值和雅可比行为进行操作.但是,我无法使NumPy ufuncs(例如np.log)在我的自定义"数组"上工作.
我创建了以下最小示例,说明了该问题.Two应该是NumPy数组的辐射加固版本,它可以计算两次,并确保结果相同.
它必须是支持索引,元素对数和长度.就像平常一样ndarray.当使用调用时,元素对数工作正常x.cos(),但在调用时会执行意外操作np.cos(x).
from __future__ import print_function
import numpy as np
class Two(object):
def __init__(self, val1, val2):
print("init with", val1, val2)
assert np.array_equal(val1, val2)
self.val1 = val1
self.val2 = val2
def __getitem__(self, s):
print("getitem", s, "got", Two(self.val1[s], self.val2[s]))
return Two(self.val1[s], self.val2[s])
def __repr__(self):
return "<<{}, {}>>".format(self.val1, self.val2)
def log(self):
print("log", self)
return Two(np.log(self.val1), np.log(self.val2))
def __len__(self):
print("len", self, "=", self.val1.shape[0])
return self.val1.shape[0]
x = Two(np.array([1,2]).T, np.array([1,2]).T)
Run Code Online (Sandbox Code Playgroud)
索引按预期返回两个属性中的相关元素:
>>> print("First element in x:", x[0], "\n")
init with [1 2] [1 2]
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
First element in x: <<1, 1>>
Run Code Online (Sandbox Code Playgroud)
使用时调用元素对数可以正常工作x.cos():
>>> print("--- x.log() ---", x.log(), "\n")
log <<[1 2], [1 2]>>
init with [ 0. 0.69314] [ 0. 0.69314]
--- x.log() --- <<[ 0. 0.69314], [ 0. 0.69314]>>
Run Code Online (Sandbox Code Playgroud)
但是,np.log(x)不能按预期工作.它意识到对象有一个长度,所以它提取每个项目并在每个项目上取对数,然后返回一个两个对象的数组(dtype = object).
>>> print("--- np.log(x) with len ---", np.log(x), "\n") # WTF
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
len <<[1 2], [1 2]>> = 2
len <<[1 2], [1 2]>> = 2
init with 1 1
getitem 0 got <<1, 1>>
init with 1 1
init with 2 2
getitem 1 got <<2, 2>>
init with 2 2
log <<1, 1>>
init with 0.0 0.0
log <<2, 2>>
init with 0.693147 0.693147
--- np.log(x) with len --- [<<0.0, 0.0>> <<0.693147, 0.693147>>]
Run Code Online (Sandbox Code Playgroud)
如果Two没有长度方法,它可以正常工作:
>>> del Two.__len__
>>> print("--- np.log(x) without len ---", np.log(x), "\n")
log <<[1 2], [1 2]>>
init with [ 0. 0.69314718] [ 0. 0.693147]
--- np.log(x) without len --- <<[ 0. 0.693147], [ 0. 0.693147]>>
Run Code Online (Sandbox Code Playgroud)
如何创建满足要求的类(getitem,log,len)?我研究了子类化ndarray,但这似乎比它的价值更复杂.
另外,我找不到NumPy源代码中 x.__len__访问的位置,所以我也对此感兴趣.
编辑:我正在使用miniconda2与Python 2.7.11和NumPy 1.11.0.
| 归档时间: |
|
| 查看次数: |
801 次 |
| 最近记录: |