pytorch中的高阶梯度

use*_*261 5 python gradient pytorch autograd

我在pytorch中实现了以下雅可比函数.除非我犯了一个错误,否则它会计算任何张量输入的任何张量的雅可比行列式:

import torch
import torch.autograd as ag

def nd_range(stop, dims = None):
    if dims == None:
        dims = len(stop)
    if not dims:
        yield ()
        return
    for outer in nd_range(stop, dims - 1):
        for inner in range(stop[dims - 1]):
            yield outer + (inner,)


def full_jacobian(f, wrt):    
    f_shape = list(f.size())
    wrt_shape = list(wrt.size())
    fs = []


    f_range = nd_range(f_shape)
    wrt_range = nd_range(wrt_shape)

    for f_ind in f_range:
        grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
        for i in range(len(f_shape)):
            grad = grad.unsqueeze(0)
        fs.append(grad)

    fj = torch.cat(fs, dim=0)
    fj = fj.view(f_shape + wrt_shape)
    return fj
Run Code Online (Sandbox Code Playgroud)

最重要的是,我试图实现一个递归函数来计算n阶导数:

def nth_derivative(f, wrt, n):
    if n == 1:
        return full_jacobian(f, wrt)
    else:        
        deriv = nth_derivative(f, wrt, n-1)
        return full_jacobian(deriv, wrt)
Run Code Online (Sandbox Code Playgroud)

我做了一个简单的测试:

op = torch.ger(s, s)
deep_deriv = nth_derivative(op, s, 5)
Run Code Online (Sandbox Code Playgroud)

不幸的是,这成功地让我成为了Hessian ......但没有更高阶的衍生品.我知道许多高阶导数应该是0,但我更喜欢pytorch可以分析计算它.

一个修复方法是将渐变计算更改为:

try:
            grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
        except:
            grad = torch.zeros_like(wrt)
Run Code Online (Sandbox Code Playgroud)

这是接受的正确方法吗?或者有更好的选择吗?或者我是否有理由认为我的问题完全错误?

Ale*_*lex 8

你可以迭代调用grad函数:

import torch
from torch.autograd import grad

def nth_derivative(f, wrt, n):

    for i in range(n):

        grads = grad(f, wrt, create_graph=True)[0]
        f = grads.sum()

    return grads

x = torch.arange(4, requires_grad=True).reshape(2, 2)
loss = (x ** 4).sum()

print(nth_derivative(f=loss, wrt=x, n=3))
Run Code Online (Sandbox Code Playgroud)

输出

tensor([[  0.,  24.],
        [ 48.,  72.]])
Run Code Online (Sandbox Code Playgroud)

  • “这样的迭代方法会影响性能吗?” 是一个很模糊的问题。它可能会或可能不会取决于您的代码的其余部分。除非您遇到性能问题并确定这是一个瓶颈,否则不要担心。[引用 Donald Knuth](http://wiki.c2.com/?PrematureOptimization) “过早的优化是万恶之源”。 (3认同)
  • 这里有一个大问题;很惊讶之前没有人提到过。求和,然后求导,得到二阶导数加上交叉导数的总和(随着 for 循环的继续,它们会累积更多)。在上面的示例中(即 x**4),所有这些交叉导数在构造上均为零。 (3认同)
  • @ user650261 “不是最干净的解决方案”是什么意思? (2认同)