即使我们不使用model.fit,何时应该继承keras.Model而不是keras.layers.Layer?

yun*_*abe 5 keras tensorflow tf.keras tensorflow2.0

在一些使用tf2的Tensorflow教程中(例如,具有AttentionEager要点的神经机器翻译),它们定义了custom tf.keras.Model而不是tf.keras.layers.Layers(例如BahdanauAttention(tf.keras.Model):

另外,模型:组成层文档tf.keras.Model明确使用。本节说:

创建包含其他图层的类似图层的东西时使用的主要类是tf.keras.Model。通过从tf.keras.Model继承来实现一个。

听起来我们需要继承tf.keras.Model以定义组成子图层的图层。

但是,据我检查,即使我将定义ResnetIdentityBlock为的子类,此代码也可以使用tf.keras.layers.Layer。其他两个教程也可以使用Layer。除此之外,另一个教程

Model is just like a Layer, but with added training and serialization utilities.

Thus, I have no idea what is the real difference between tf.keras.Model and tf.keras.layers.Layer and why those three tutorial with Eager execution uses tf.keras.Model though they don't use training and serialization utilities of tf.keras.Model.

Why do we need to inherit tf.keras.Model in those tutorials?

Additional comment

utilities of Model work only with special subsets of Layer (Layers whose call receive only one input). Thus, I think the idea like "Always extend Model because Model has more features" is not correct. Also, it violates a basic programming program like SRP.

eug*_*gen 1

更新

所以评论是:Yes, I know training and serialization utilities exist in Model as I wrote in the question. My question is why TF tutorials need to use Model though they don't use these methods.

在这种情况下,作者可以提供最佳答案,因为您的问题是问为什么他们选择一种方法而不是另一种方法,而他们都可以同样出色地完成工作。为什么能同样出色地完成工作?嗯,因为Model is just like a Layer, but with added training and serialization utilities.

我们可以说,当只有层可以完成工作时使用模型是一种矫枉过正,但这可能是一个品味问题。

希望能帮助到你

附言。

在您提供的热切示例自定义图层编写教程中,我们无法用图层替换模型,因此这些教程不适用于您的问题


使用模型,您可以训练,但仅使用图层,则不能。请参阅下面的方法列表(不包括内部和继承的方法):

tf.keras.layers.Layer

activity_regularizer
activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_mask
compute_output_shape
count_params
dtype
dynamic
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
trainable
trainable
trainable_variables
trainable_weights
updates
variables
weights
Run Code Online (Sandbox Code Playgroud)

看?那里没有适合或评估方法。 tf.keras.Model


compile
evaluate
evaluate_generator
fit
fit_generator
get_weights
load_weights
metrics
metrics_names
predict
predict_generator
predict_on_batch
reset_metrics
run_eagerly
run_eagerly
sample_weights
test_on_batch
train_on_batch
Run Code Online (Sandbox Code Playgroud)