在 Keras 中,使用Flatten()层会保留批量大小。例如,如果 Flatten 的输入形状为(32, 100, 100),则KerasFlatten 的输出为(32, 10000),但在 PyTorch 中为320000。为什么会这样呢?
jod*_*dag 12
正如 OP 在他们的回答中已经指出的那样,张量操作不会默认考虑批量维度。您可以使用torch.flatten()或Tensor.flatten()withstart_dim=1在批量维度之后开始展平操作。
或者,从 PyTorch 1.2.0 开始,您可以nn.Flatten()在模型中定义一个默认为start_dim=1.
| 归档时间: |
|
| 查看次数: |
3512 次 |
| 最近记录: |