How to remove the first N layers from a Keras model?

seb*_*-sz 6 python deep-learning keras tensorflow efficientnet

I would like to remove the first N layers from a pretrained Keras model. For example, in an EfficientNetB0, the first 3 layers are responsible only for preprocessing:

import tensorflow as tf

efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)

print(efinet.layers[:3])
# [<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fa9a870e4d0>,
# <tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling at 0x7fa9a61343d0>,
# <tensorflow.python.keras.layers.preprocessing.normalization.Normalization at 0x7fa9a60d21d0>]

As M.Innat mentioned, the first layer is the Input layer, which should be either kept or re-attached. I would like to remove those layers, but a simple approach along these lines (sketched below with an illustrative snippet) raises an error:

import tensorflow as tf

efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)

# Naive attempt (illustrative): attach a fresh Input and reuse the old
# output tensor. The new tensor is never wired into the pretrained graph,
# so Keras cannot trace a path from it to the requested output.
new_input = tf.keras.Input(shape=(224, 224, 3))
efinet = tf.keras.Model(new_input, efinet.layers[-1].output)

This results in:

ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(...)

What is the recommended way to do this?

M.I*_*nat 2

The Graph disconnected error occurs because you don't specify the graph's Input layer. But that is not the main issue here. Sometimes, removing intermediate layers from a keras model is not straightforward with either the Sequential or the Functional API.
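
To make that diagnosis concrete, here is a minimal sketch of my own (not from the original answer; stem_activation is a standard EfficientNetB0 layer name, visible in the listing further down). When the graph's own Input is kept, every requested output is reachable from the supplied input, and sub-model construction succeeds:

import tensorflow as tf

efinet = tf.keras.applications.EfficientNetB0(weights=None)

# Keeping the original Input keeps the graph connected, so no
# "Graph disconnected" error is raised.
backbone = tf.keras.Model(
    inputs=efinet.input,
    outputs=efinet.get_layer("stem_activation").output,
)
print(backbone.output_shape)  # (None, 112, 112, 32)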

For a Sequential model it should be relatively easy, whereas for a functional model you need to take care of multi-input blocks (e.g. multiply, add, etc.). For example, if you want to remove some intermediate layers from a Sequential model, you can easily adapt a rebuild-from-kept-layers solution (see the sketch after this paragraph). But for a functional model (efficientnet) you can't, because of the multi-input internal blocks, and you would run into the following error: ValueError: A merged layer should be called on a list of inputs. So it needs a bit more work, and as far as I know, the following is one possible way to overcome it.
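
A minimal sketch of the Sequential case (the toy model and its layer choices are illustrative, not from the original post). Because a Sequential graph is a plain chain with no merge blocks, a new model can simply be rebuilt from the layers you want to keep:

import tensorflow as tf

# Toy Sequential model (illustrative); the Rescaling layer stands in for
# the preprocessing we want to drop.
seq_model = tf.keras.Sequential([
    tf.keras.Input(shape=(224, 224, 3)),
    tf.keras.layers.Rescaling(1.0 / 255),  # tf.keras.layers.Rescaling needs TF >= 2.6
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

# Sequential.layers excludes the InputLayer, so layers[0] is Rescaling.
# Rebuilding from the kept layers drops it; the weights are shared, not copied.
trimmed = tf.keras.Sequential(
    [tf.keras.Input(shape=(224, 224, 3))] + seq_model.layers[1:]
)
trimmed.summary()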


Here I will show a simple workaround for your case, but it is probably not general and can also be unsafe in some cases. It is based on the .pop method (why that is unsafe to use is made concrete at the end). Okay, let's first load the model.

import tensorflow as tf

func_model = tf.keras.applications.EfficientNetB0()

for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8: break

input_19 [(None, 224, 224, 3)]
rescaling_13 (None, 224, 224, 3)
normalization_13 (None, 224, 224, 3)
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)

Next, use the .pop method on the model's private _layers list:

func_model._layers.pop(1) # remove rescaling
func_model._layers.pop(1) # remove normalization

for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8: break

input_22 [(None, 224, 224, 3)]
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
block1a_activation (None, 112, 112, 32)
block1a_se_squeeze (None, 32)
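
A minimal sanity check makes the "unsafe" caveat concrete (the random input and ref_model below are my own illustration, not part of the original answer). .pop only edits the private _layers bookkeeping list, while a functional model executes its recorded graph of layer nodes, so the removed preprocessing can still run in the forward pass:

import numpy as np
import tensorflow as tf

# Untouched copy with the same default imagenet weights as func_model.
ref_model = tf.keras.applications.EfficientNetB0()

x = np.random.rand(1, 224, 224, 3).astype("float32") * 255.0

# True here means Rescaling/Normalization still execute in the forward
# pass, even though they no longer show up in func_model.layers.
print(np.allclose(ref_model.predict(x), func_model.predict(x)))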