seb*_*-sz 6 python deep-learning keras tensorflow efficientnet
我想从预训练的 Keras 模型中删除前N层。例如, an EfficientNetB0
,其前3层仅负责预处理:
import tensorflow as tf
efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)
print(efinet.layers[:3])
# [<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fa9a870e4d0>,
# <tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling at 0x7fa9a61343d0>,
# <tensorflow.python.keras.layers.preprocessing.normalization.Normalization at 0x7fa9a60d21d0>]
Run Code Online (Sandbox Code Playgroud)
正如M.Innat提到的,第一层是Input Layer
,应该保留或重新附加。我想删除这些层,但是像这样的简单方法会引发错误:
import tensorflow as tf
efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)
print(efinet.layers[:3])
# [<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fa9a870e4d0>,
# <tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling at 0x7fa9a61343d0>,
# <tensorflow.python.keras.layers.preprocessing.normalization.Normalization at 0x7fa9a60d21d0>]
Run Code Online (Sandbox Code Playgroud)
这将导致:
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(...)
Run Code Online (Sandbox Code Playgroud)
推荐的方法是什么?
出现错误的原因Graph disconnected
是因为您没有指定图层Input
。但这不是这里的主要问题。有时,使用API从模型中删除中间层keras
并不简单。Sequential
Functional
对于顺序,它相对应该很容易,而在功能模型中,您需要关心多输入块(例如multiply
,add
等等)。例如:如果您想删除顺序模型中的某些中间层,您可以轻松适应此解决方案。但对于功能模型( ),由于多输入内部模块,efficientnet
您不能这样做,并且您将遇到以下错误:。因此,这需要更多的工作,据我所知,这是一种可能的方法来克服它。ValueError: A merged layer should be called on a list of inputs
在这里,我将针对您的情况展示一个简单的解决方法,但它可能不通用,并且在某些情况下也不安全。即基于这种方法;使用pop
方法。为什么使用起来不安全!。好的,我们首先加载模型。
func_model = tf.keras.applications.EfficientNetB0()
for i, l in enumerate(func_model.layers):
print(l.name, l.output_shape)
if i == 8: break
input_19 [(None, 224, 224, 3)]
rescaling_13 (None, 224, 224, 3)
normalization_13 (None, 224, 224, 3)
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
Run Code Online (Sandbox Code Playgroud)
接下来,使用.pop
方法:
func_model._layers.pop(1) # remove rescaling
func_model._layers.pop(1) # remove normalization
for i, l in enumerate(func_model.layers):
print(l.name, l.output_shape)
if i == 8: break
input_22 [(None, 224, 224, 3)]
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
block1a_activation (None, 112, 112, 32)
block1a_se_squeeze (None, 32)
Run Code Online (Sandbox Code Playgroud)
归档时间: |
|
查看次数: |
4915 次 |
最近记录: |