torch.flatten() 和 nn.Flatten() 的区别

jul*_*les 5 python neural-network difference pytorch tensor

torch.flatten()和之间有什么区别torch.nn.Flatten()

Iva*_*van 6

扁平化在 PyTorch 中以三种形式提供

  • 作为一种张量的方法(OOP样式torch.Tensor.flatten直接在一个张量施加:x.flatten()

  • 作为一个函数(函数形式torch.flatten应用为:torch.flatten(x)

  • 作为一个模块(nn.Modulenn.Flatten()。通常用于模型定义。

所有这三个是相同的并且共享相同的实施方式中,唯一的区别是nn.Flatten已经start_dim设置为1默认,以避免平坦化所述第一轴线(通常是分批轴)。而其他两个从axis=0to变平axis=-1-整个张量 - 如果没有给出参数。


kma*_*o23 5

您可以将这项工作torch.flatten()视为简单地对张量进行展平操作,而不附加任何条件。你给出一个张量,它会展平,然后返回它。这就是全部了。

相反,nn.Flatten()它要复杂得多(即,它是一个神经网络层)。作为面向对象的,它继承自nn.Module,尽管它在内部使用普通的tensor.flatten() OP 方法来forward()平张量。您可以将其视为语法糖 over torch.flatten()


重要区别:一个显着的区别是,只要输入torch.flatten() 至少为 1D 或更大,则nn.Flatten() 始终返回 1D 张量作为结果,而只要输入至少为 2D 或更大,则始终返回 2D 张量(以 1D 张量作为输入,它会抛出一个IndexError)。


比较:

  • torch.flatten()是 API,而nn.Flatten()是神经网络层。

  • torch.flatten()是一个 python函数,而nn.Flatten()是一个 python

  • 由于以上这一点,nn.Flatten()附带了很多方法和属性

  • torch.flatten()可以在野外使用(例如,对于简单的张量OP),而预计在块中作为层之一nn.Flatten()使用。nn.Sequential()

  • torch.flatten()没有有关计算图的信息,除非它被卡在其他图感知块中(tensor.requires_grad标志设置为True),而nn.Flatten()始终由 autograd 跟踪。

  • torch.flatten()无法接受和处理(例如,线性/Conv1D)层作为输入,而nn.Flatten()主要用于处理这些神经网络层。

  • 两者torch.flatten()nn.Flatten()返回输入张量的视图。因此,对结果的任何修改也会影响输入张量。(见下面的代码)


代码演示

# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor
Run Code Online (Sandbox Code Playgroud)

压平torch.flatten()

In [113]: t1flat = torch.flatten(t1)

In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# modification to the flattened tensor    
In [115]: t1flat[-1] = -1

# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]: 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])
Run Code Online (Sandbox Code Playgroud)

压平nn.Flatten()

In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)

# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]: 
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])

# modification to the result
In [126]: t3flat[-1, -1] = -1

# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]: 
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])
Run Code Online (Sandbox Code Playgroud)

花絮:是它及其兄弟torch.flatten()的前身,因为它从一开始就存在。然后,有一个合法的用例,因为这是几乎所有 ConvNet 的共同要求(就在 softmax 之前或其他地方)。所以它后来被添加到PR #22245中。nn.Flatten() nn.Unflatten()nn.Flatten()

最近还有在 ResNet 中用于nn.Flatten()模型手术的建议。