在 PyTorch 中,grad_fn 属性到底存储什么以及如何使用?

Dav*_*ian 18 python oop pytorch

在 PyTorch 中,Tensor类有一个grad_fn属性。这引用了用于获取张量的操作:例如, if a = b + 2, a.grad_fnwill be AddBackward0。但“参考”到底是什么意思呢?

检查AddBackward0usinginspect.getmro(type(a.grad_fn))将指出 的唯一基类AddBackward0object此外,在源代码grad_fn中找不到该类的源代码(事实上,在 中可能遇到的任何其他类) !

所有这些让我想到以下问题:

  1. 反向传播期间到底存储了什么grad_fn以及如何调用它?
  2. 为什么存储的对象grad_fn没有某种通用的超类,为什么 GitHub 上没有它们的源代码?

Pru*_*une 7

grad_fn是一个函数“句柄”,可以访问适用的梯度函数。给定点的梯度是反向传播时调整权重的系数。

“句柄”是对象描述符的通用术语,旨在提供对对象的适当访问。例如,当您打开文件时,open返回文件句柄。当您实例化一个类时,该__init__函数将返回所创建实例的句柄。句柄包含对相关项目的数据和函数的引用(通常是内存地址)。

它显示为泛型object类,因为它来自另一种语言的底层实现,因此它不完全映射到 Pythonfunction类型。PyTorch 处理跨语言调用和返回。这种切换是预编译(共享对象)运行时系统的一部分。

这足以澄清你所看到的吗?