我正在尝试对批量标准化进行研究,并且不得不对 pytorch BN 代码进行一些修改。我深入研究了 pytorch 代码并陷入了困境torch.nn.functional.batch_norm
,其中引用了torch.batch_norm
.
问题是torch.batch_norm
无法在火炬库中进一步找到。有什么办法可以找到这个内置函数的源代码并重新实现它吗?谢谢!
它在那里,但它没有在 Python 中定义。它们在aten/
目录中的 C++ 中定义。
对于 CPU,实现(其中之一,取决于输入是否连续)在这里:https : //github.com/pytorch/pytorch/blob/420b37f3c67950ed93cd8aa7a12e673fcfc5567b/aten/src/ATen/native/Normalization。 cpp#L61-L126
对于 CUDA,实现在这里:https : //github.com/pytorch/pytorch/blob/7aae51cdedcbf0df5a7a8bf50a947237ac4b3ee8/aten/src/ATen/native/cudnn/BatchNorm.cpp#L52-L143
归档时间: |
|
查看次数: |
1302 次 |
最近记录: |