如何在pytorch中查找内置函数源代码

Sto*_*ree 3 pytorch

我正在尝试对批量标准化进行研究,并且不得不对 pytorch BN 代码进行一些修改。我深入研究了 pytorch 代码并陷入了困境torch.nn.functional.batch_norm,其中引用了torch.batch_norm.

问题是torch.batch_norm无法在火炬库中进一步找到。有什么办法可以找到这个内置函数的源代码并重新实现它吗?谢谢!

Jos*_*rty 7

它在那里,但它没有在 Python 中定义。它们在aten/目录中的 C++ 中定义。

对于 CPU,实现(其中之一,取决于输入是否连续)在这里:https : //github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization。 cpp#L61-L126

对于 CUDA,实现在这里:https : //github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143

  • 我不知道有什么映射。我刚刚在 GitHub 上的存储库中搜索了“batch_norm”,并手动查看了结果。我密切关注“/aten”文件夹,因为那是大多数 C++ 源代码所在的位置。 (2认同)