jul*_*les 5 python neural-network difference pytorch tensor
torch.flatten()和之间有什么区别torch.nn.Flatten()?
扁平化在 PyTorch 中以三种形式提供
作为一种张量的方法(OOP样式)torch.Tensor.flatten直接在一个张量施加:x.flatten()。
作为一个函数(函数形式)torch.flatten应用为:torch.flatten(x)。
作为一个模块(层nn.Module)nn.Flatten()。通常用于模型定义。
所有这三个是相同的并且共享相同的实施方式中,唯一的区别是nn.Flatten已经start_dim设置为1默认,以避免平坦化所述第一轴线(通常是分批轴)。而其他两个从axis=0to变平axis=-1-即整个张量 - 如果没有给出参数。
您可以将这项工作torch.flatten()视为简单地对张量进行展平操作,而不附加任何条件。你给出一个张量,它会展平,然后返回它。这就是全部了。
相反,nn.Flatten()它要复杂得多(即,它是一个神经网络层)。作为面向对象的,它继承自nn.Module,尽管它在内部使用普通的tensor.flatten() OP 方法来forward()展平张量。您可以将其视为语法糖 over torch.flatten()。
重要区别:一个显着的区别是,只要输入torch.flatten() 至少为 1D 或更大,则nn.Flatten() 始终返回 1D 张量作为结果,而只要输入至少为 2D 或更大,则始终返回 2D 张量(以 1D 张量作为输入,它会抛出一个IndexError)。
torch.flatten()是 API,而nn.Flatten()是神经网络层。
torch.flatten()是一个 python函数,而nn.Flatten()是一个 python类。
由于以上这一点,nn.Flatten()附带了很多方法和属性
torch.flatten()可以在野外使用(例如,对于简单的张量OP),而预计在块中作为层之一nn.Flatten()使用。nn.Sequential()
torch.flatten()没有有关计算图的信息,除非它被卡在其他图感知块中(tensor.requires_grad标志设置为True),而nn.Flatten()始终由 autograd 跟踪。
torch.flatten()无法接受和处理(例如,线性/Conv1D)层作为输入,而nn.Flatten()主要用于处理这些神经网络层。
两者torch.flatten()都nn.Flatten()返回输入张量的视图。因此,对结果的任何修改也会影响输入张量。(见下面的代码)
代码演示:
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor
Run Code Online (Sandbox Code Playgroud)
压平torch.flatten():
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, -1]])
Run Code Online (Sandbox Code Playgroud)
压平nn.Flatten():
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27],
[28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27]],
[[28, 29, 30, 31],
[32, 33, 34, -1]]])
Run Code Online (Sandbox Code Playgroud)
花絮:是它及其兄弟torch.flatten()的前身,因为它从一开始就存在。然后,有一个合法的用例,因为这是几乎所有 ConvNet 的共同要求(就在 softmax 之前或其他地方)。所以它后来被添加到PR #22245中。nn.Flatten() nn.Unflatten()nn.Flatten()
最近还有在 ResNet 中用于nn.Flatten()模型手术的建议。