VF1*_*VF1 3 python recursion numpy
金融和强化学习中的常用术语是C[i]基于原始奖励的时间序列的折扣累积奖励R[i].给定一个数组R,我们想用(并返回完整的数组)计算C[i]满足重现.C[i] = R[i] + discount * C[i+1]C[-1] = R[-1]C
在numpy数组的python中计算这个数值稳定的方法可能是:
import numpy as np
def cumulative_discount(rewards, discount):
future_cumulative_reward = 0
assert np.issubdtype(rewards.dtype, np.floating), rewards.dtype
cumulative_rewards = np.empty_like(rewards)
for i in range(len(rewards) - 1, -1, -1):
cumulative_rewards[i] = rewards[i] + discount * future_cumulative_reward
future_cumulative_reward = cumulative_rewards[i]
return cumulative_rewards
Run Code Online (Sandbox Code Playgroud)
但是,这依赖于python循环.鉴于这是一个如此常见的计算,当然有一个现有的矢量化解决方案依赖于其他一些标准函数而不需要求助于cythonization.
请注意,使用类似内容的任何解决方案np.power(discount, np.arange(len(rewards))都不会很稳定.
您可以使用scipy.signal.lfilter来解决递归关系:
def alt(rewards, discount):
"""
C[i] = R[i] + discount * C[i+1]
signal.lfilter(b, a, x, axis=-1, zi=None)
a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
- a[1]*y[n-1] - ... - a[N]*y[n-N]
"""
r = rewards[::-1]
a = [1, -discount]
b = [1]
y = signal.lfilter(b, a, x=r)
return y[::-1]
Run Code Online (Sandbox Code Playgroud)
此脚本测试结果是否相同:
import scipy.signal as signal
import numpy as np
def orig(rewards, discount):
future_cumulative_reward = 0
cumulative_rewards = np.empty_like(rewards, dtype=np.float64)
for i in range(len(rewards) - 1, -1, -1):
cumulative_rewards[i] = rewards[i] + discount * future_cumulative_reward
future_cumulative_reward = cumulative_rewards[i]
return cumulative_rewards
def alt(rewards, discount):
"""
C[i] = R[i] + discount * C[i+1]
signal.lfilter(b, a, x, axis=-1, zi=None)
a[0]*y[n] = b[0]*x[n] + b[1]*x[n-1] + ... + b[M]*x[n-M]
- a[1]*y[n-1] - ... - a[N]*y[n-N]
"""
r = rewards[::-1]
a = [1, -discount]
b = [1]
y = signal.lfilter(b, a, x=r)
return y[::-1]
# test that the result is the same
np.random.seed(2017)
for i in range(100):
rewards = np.random.random(10000)
discount = 1.01
expected = orig(rewards, discount)
result = alt(rewards, discount)
if not np.allclose(expected,result):
print('FAIL: {}({}, {})'.format('alt', rewards, discount))
break
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
759 次 |
| 最近记录: |