Tab*_*bin 10 python numpy tensorflow
我正在编写一个小代码来使用张量流中的有限差分方法计算四阶导数。如下:
def action(y,x):
#spacing between points.
h = (x[-1] - x[0]) / (int(x.shape[0]) - 1)
#fourth derivative
dy4 = (y[4:] - 4*y[3:-1] + 6*y[2:-2] - 4*y[1:-3] + y[:-4])/(h*h*h*h)
return dy4
x = tf.linspace(0.0, 30, 1000)
y = tf.tanh(x)
dy4 = action(y,x)
sess = tf.compat.v1.Session()
plt.plot(sess.run(dy4))
Run Code Online (Sandbox Code Playgroud)
结果如下图所示:
但是,如果我使用基本相同的代码但仅使用 numpy,结果会更清晰:
def fourth_deriv(y, x):
h = (x[-1] - x[0]) / (int(x.shape[0]) - 1)
dy = (y[4:] - 4*y[3:-1] + 6*y[2:-2] - 4*y[1:-3] + y[:-4])/(h*h*h*h)
return dy
x = np.linspace(0.0, 30, 1000)
test = fourth_deriv(np.tanh(x), x)
plt.plot(test)
Run Code Online (Sandbox Code Playgroud)
这使:
这里有什么问题呢?我一开始认为点之间的间隔可能太小而无法进行准确的计算,但显然,如果 numpy 可以很好地处理它,情况就不是这样了。
ζ--*_*ζ-- 10
该问题与浮点类型的选择有关。
tf.linspace自动选择tf.float32其类型,而np.linspace创建一个float64数组,它的精度更高。进行如下修改:
start = tf.constant(0.0, dtype = tf.float64)
end = tf.constant(30.0, dtype = tf.float64)
x = tf.linspace(start, end, 1000)
Run Code Online (Sandbox Code Playgroud)
进一步值得注意的是,Tensorflow 确实包含自动微分,这对于机器学习训练至关重要,因此经过了充分测试 - 您可以使用梯度磁带访问它并评估四阶导数,而不会出现使用有限差分的数值微分的不精确性:
with tf.compat.v1.Session() as sess2:
x = tf.Variable(tf.linspace(0, 30, 1000))
sess2.run(tf.compat.v1.initialize_all_variables())
with tf.GradientTape() as t4:
with tf.GradientTape() as t3:
with tf.GradientTape() as t2:
with tf.GradientTape() as t1:
y = tf.tanh(x)
der1 = t1.gradient(y, x)
der2 = t2.gradient(der1, x)
der3 = t3.gradient(der2, x)
der4 = t4.gradient(der3, x)
print(der4)
plt.plot(sess2.run(der4))
Run Code Online (Sandbox Code Playgroud)
该方法的精度远远优于使用有限差分方法所能达到的精度。下面的代码比较了auto diff的精度和有限差分法的精度:
x = np.linspace(0.0, 30, 1000)
sech = 1/np.cosh(x)
theoretical = 16*np.tanh(x) * np.power(sech, 4) - 8*np.power(np.tanh(x), 3)*np.power(sech,2)
finite_diff_err = theoretical[2:-2] - from_finite_diff
autodiff_err = theoretical[2:-2] - from_autodiff[2:-2]
print('Max err with autodiff: %s' % np.max(np.abs(autodiff_err)))
print('Max err with finite difference: %s' % np.max(np.abs(finite_diff_err)))
line, = plt.plot(np.log10(np.abs(autodiff_err)))
line.set_label('Autodiff log error')
line2, = plt.plot(np.log10(np.abs(finite_diff_err)))
line2.set_label('Finite difference log error')
plt.legend()
Run Code Online (Sandbox Code Playgroud)
并产生以下输出:
Max err with autodiff: 3.1086244689504383e-15
Max err with a finite difference: 0.007830900165363808
Run Code Online (Sandbox Code Playgroud)