Pytorch 中的缓冲区是什么?

lum*_*uri 7 python pytorch

我了解register_buffer 的作用以及register_buffer 和 register_parameters之间的区别。

但是 PyTorch 中缓冲区的准确定义是什么?

Ber*_*iel 4

这可以通过查看实现来回答:

def register_buffer(self, name, tensor):
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor
Run Code Online (Sandbox Code Playgroud)

即缓冲区的名称:

  • 必须是一个字符串:not isinstance(name, torch._six.string_classes)
  • 不能包含.(点):'.' in name
  • 不能为空字符串:name == ''
  • 不能是模块的属性:hasattr(self, name)
  • 应该是唯一的:name not in self._buffers

以及tensor(你猜怎么着?):

  • 应该是一个张量:isinstance(tensor, torch.Tensor)

因此,缓冲区只是一个具有这些属性的张量,注册在_buffersa 的属性中Module