I understand what register_buffer does and how register_buffer differs from register_parameter.
But what is the precise definition of a buffer in PyTorch?
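For reference, the difference the question mentions can be seen in a small sketch. It assumes a recent PyTorch release, and the module Scaler and its attribute names are hypothetical. A parameter is returned by parameters(), so an optimizer will update it. A buffer is not returned by parameters(), but it is still saved in state_dict() and converted together with the module (here by .double()):

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):
    """Hypothetical module holding one parameter and one buffer."""

    def __init__(self):
        super().__init__()
        # Parameter: returned by parameters(), so an optimizer would train it.
        self.register_parameter("weight", nn.Parameter(torch.ones(3)))
        # Buffer: part of the module's state, but not a trainable parameter.
        self.register_buffer("count", torch.zeros(3))

    def forward(self, x):
        return x * self.weight + self.count


m = Scaler()
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)            # ['weight']
print(buffer_names)           # ['count']
print(state_keys)             # ['weight', 'count']
print(m.count.requires_grad)  # False

# Module-wide conversions such as .double() (or .to(device)) apply to buffers too.
m.double()
print(m.count.dtype)          # torch.float64
```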
This can be answered by looking at the implementation:
# Method of torch.nn.Module as found in an older PyTorch release
# (newer releases differ slightly, e.g. torch._six no longer exists).
def register_buffer(self, name, tensor):
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor
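Each of these checks can be triggered on a bare nn.Module. The sketch below assumes a reasonably recent PyTorch release, where the name check uses str instead of torch._six. The helper error_of and the names some_attr, buf, empty_slot and TooEarly are hypothetical and used only for illustration:

```python
import torch
import torch.nn as nn

m = nn.Module()
m.some_attr = 42  # an ordinary (non-tensor) Python attribute


def error_of(name, tensor):
    """Return the exception type register_buffer raises, or None on success."""
    try:
        m.register_buffer(name, tensor)
    except (KeyError, TypeError) as exc:
        return type(exc)
    return None


# Evaluated in order, so later entries see the buffers registered earlier.
results = {
    "non-string name": error_of(123, torch.zeros(1)),       # TypeError
    "dotted name": error_of("a.b", torch.zeros(1)),         # KeyError
    "empty name": error_of("", torch.zeros(1)),             # KeyError
    "existing attribute": error_of("some_attr", torch.zeros(1)),  # KeyError
    "non-tensor value": error_of("buf", [1, 2, 3]),         # TypeError
    "valid tensor": error_of("buf", torch.zeros(1)),        # None (registered)
    "re-register same buffer": error_of("buf", torch.ones(1)),  # None (replaced)
    "None placeholder": error_of("empty_slot", None),       # None (allowed)
}


class TooEarly(nn.Module):
    def __init__(self):
        # Wrong on purpose: registering before nn.Module.__init__() has run.
        self.register_buffer("x", torch.zeros(1))
        super().__init__()


try:
    TooEarly()
    early_error = None
except AttributeError as exc:
    early_error = type(exc)

for case, err in results.items():
    print(case, "->", err)
print("before __init__ ->", early_error)
```

Note that registering a name that is already a buffer is accepted and simply replaces the stored tensor; the "already exists" check only fires for attributes that are not buffers.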
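That is, the buffer's name:
- must be a string (rejected when not isinstance(name, torch._six.string_classes));
- cannot contain "." (dot) (rejected when '.' in name);
- cannot be the empty string (rejected when name == '');
- cannot clash with an existing attribute of the Module, unless that attribute is already a buffer, in which case registering it again simply replaces it (rejected when hasattr(self, name) and name not in self._buffers);

and the tensor (guess what?):
- must be a torch.Tensor or None (rejected when tensor is not None and not isinstance(tensor, torch.Tensor)).

So a buffer is simply a tensor (or None) that satisfies these properties and is stored in the _buffers dictionary of a Module.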
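The last point can be checked directly. This sketch assumes a PyTorch version that supports the persistent argument of register_buffer; the attribute names are hypothetical, and _buffers is a private attribute that is inspected here only to illustrate the definition:

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("running", torch.zeros(2))
m.plain = torch.ones(2)  # plain tensor attribute: NOT registered as a buffer
m.register_buffer("scratch", torch.zeros(2), persistent=False)

print(list(m._buffers))                    # ['running', 'scratch']
print([n for n, _ in m.named_buffers()])   # ['running', 'scratch']
print(list(m.state_dict()))                # ['running']

# Assigning a tensor to an existing buffer name goes through
# nn.Module.__setattr__, which updates the entry stored in _buffers.
m.running = torch.full((2,), 5.0)
print(m._buffers["running"].tolist())      # [5.0, 5.0]
```

The persistent=False buffer lives in _buffers like any other but is left out of state_dict(), while the plain tensor attribute is not a buffer at all: it is neither returned by named_buffers() nor saved in state_dict().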