为什么当作为参数传递给 Dense 层时,“input_shape”不包括批次维度?

Jen*_*sun 7 python keras tensorflow

在 Keras 中,为什么input_shape当作为参数传递给层时不包括批处理维度,Dense但在input_shape传递给build模型的方法时包含批处理维度?

import tensorflow as tf
from tensorflow.keras.layers import Dense

if __name__ == "__main__":
    model1 = tf.keras.Sequential([Dense(1, input_shape=[10])])
    model1.summary()

    model2 = tf.keras.Sequential([Dense(1)])
    model2.build(input_shape=[None, 10])  # why [None, 10] and not [10]?
    model2.summary()
Run Code Online (Sandbox Code Playgroud)

这是 API 设计的有意识选择吗?如果是,为什么?

rvi*_*nas 7

您可以通过多种不同的方式指定模型的输入形状。例如,通过向模型的第一层提供以下参数之一:

  • batch_input_shape:一个元组,其中第一个维度是批量大小。
  • input_shape:不包含批量大小的元组,例如,如果指定的话,批量大小假定为Nonebatch_size
  • input_dim:表示输入维度的标量。

在所有这些情况下,Keras 在内部存储一个属性_batch_input_size来构建模型。

关于该build方法,我的猜测是,这确实是一个有意识的选择 - 有关批量大小的信息可能有助于在某些(可能是未想到的)情况下构建模型。因此,包含批次维度作为输入的框架build比不包含批次维度的框架更加通用和完整。尽管如此,我同意你的观点,即命名论证batch_input_shape而不是input_shape会使一切更加一致。


build还值得一提的是,用户很少需要自己调用该方法。当需要时,这会在内部发生。如今,甚至可以在创建模型时忽略参数(尽管类似的方法在模型构建完成后才起作用)。在这种情况下,Keras 能够从 的参数推断输入形状。input_shapesummaryxfit