我最近不得不构建一个需要包含张量的模块。虽然反向传播使用 完美运行torch.nn.Parameter,但在打印网络对象时它没有出现。parameter与其他模块相比,为什么不包含它layer?(它不应该表现得像layer吗?)
import torch
import torch.nn as nn
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.layer = nn.Linear(10, 10)
self.parameter = torch.nn.Parameter(torch.zeros(10,10, requires_grad=True))
net = MyNet()
print(net)
Run Code Online (Sandbox Code Playgroud)
输出:
MyNet(
(layer): Linear(in_features=10, out_features=10, bias=True)
)
Run Code Online (Sandbox Code Playgroud)
当您调用 时print(net),将调用该__repr__方法。__repr__给出对象的“官方”字符串表示。
在 PyTorch 的nn.Module(MyNet模型的基类)中,__repr__实现如下:
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
Run Code Online (Sandbox Code Playgroud)
请注意,上述方法返回main_strwhich 只包含对_modulesand 的调用extra_repr,因此默认情况下它仅打印模块。
PyTorch 还提供了extra_repr()您可以自己实现的方法,用于模块的额外表示。
要打印自定义的额外信息,您应该在自己的模块中重新实现此方法。单行和多行字符串都可以接受。
| 归档时间: |
|
| 查看次数: |
2611 次 |
| 最近记录: |