zca*_*dqe 12 python deep-learning keras tensorflow
我拥有的代码(我无法更改)使用Resnet with my_input_tensor作为input_tensor.
model1 = keras.applications.resnet50.ResNet50(input_tensor=my_input_tensor, weights='imagenet')
Run Code Online (Sandbox Code Playgroud)
调查源代码,ResNet50函数创建一个新的keras输入层,my_input_tensor然后创建模型的其余部分.这是我想用自己的模型复制的行为.我从h5文件加载我的模型.
model2 = keras.models.load_model('my_model.h5')
Run Code Online (Sandbox Code Playgroud)
由于此模型已经有一个输入层,我想用一个新的输入层替换它my_input_tensor.
如何更换输入图层?
Mil*_*ore 23
使用以下方法保存模型时:
old_model.save('my_model.h5')
Run Code Online (Sandbox Code Playgroud)
它会节省以下内容:
那么,当你加载模型时:
res50_model = load_model('my_model.h5')
Run Code Online (Sandbox Code Playgroud)
你应该得到相同的型号,你可以使用以下方法验证:
res50_model.summary()
res50_model.get_weights()
Run Code Online (Sandbox Code Playgroud)
现在,您可以使用以下命令弹出输入图层并添加自己的:
res50_model.layers.pop(0)
res50_model.summary()
Run Code Online (Sandbox Code Playgroud)
添加新的输入图层:
newInput = Input(batch_shape=(0,299,299,3)) # let us say this new InputLayer
newOutputs = res50_model(newInput)
newModel = Model(newInput, newOutputs)
newModel.summary()
res50_model.summary()
Run Code Online (Sandbox Code Playgroud)
不幸的是,@MilindDeore 的解决方案对我不起作用。虽然我可以打印新模型的摘要,但在预测时收到“矩阵大小不兼容”错误。我想这是有道理的,因为密集层的新输入形状与旧密集层权重的形状不匹配。
因此,这是另一种解决方案。我的关键是使用“_layers”而不是“layers”。后者似乎只返回一个副本。
import keras
import numpy as np
def get_model():
old_input_shape = (20, 20, 3)
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(9, (3, 3), padding="same", input_shape=old_input_shape))
model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(1, activation="sigmoid"))
model.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(lr=0.0001), metrics=['acc'], )
model.summary()
return model
def change_model(model, new_input_shape=(None, 40, 40, 3)):
# replace input shape of first layer
model._layers[1].batch_input_shape = new_input_shape
# feel free to modify additional parameters of other layers, for example...
model._layers[2].pool_size = (8, 8)
model._layers[2].strides = (8, 8)
# rebuild model architecture by exporting and importing via json
new_model = keras.models.model_from_json(model.to_json())
new_model.summary()
# copy weights from old model to new one
for layer in new_model.layers:
try:
layer.set_weights(model.get_layer(name=layer.name).get_weights())
except:
print("Could not transfer weights for layer {}".format(layer.name))
# test new model on a random input image
X = np.random.rand(10, 40, 40, 3)
y_pred = new_model.predict(X)
print(y_pred)
return new_model
if __name__ == '__main__':
model = get_model()
new_model = change_model(model)
Run Code Online (Sandbox Code Playgroud)
Layers.pop(0) 或类似的东西不起作用。
您有两个选项可以尝试:
1.
您可以创建具有所需层的新模型。
一个相对简单的方法是 i) 提取模型 json 配置,ii) 适当更改它,iii) 从中创建一个新模型,然后 iv) 复制权重。我只会展示基本的想法。
i) 提取配置
model_config = model.get_config()
Run Code Online (Sandbox Code Playgroud)
ii) 更改配置
input_layer_name = model_config['layers'][0]['name']
model_config['layers'][0] = {
'name': 'new_input',
'class_name': 'InputLayer',
'config': {
'batch_input_shape': (None, 300, 300),
'dtype': 'float32',
'sparse': False,
'name': 'new_input'
},
'inbound_nodes': []
}
model_config['layers'][1]['inbound_nodes'] = [[['new_input', 0, 0, {}]]]
model_config['input_layers'] = [['new_input', 0, 0]]
Run Code Online (Sandbox Code Playgroud)
ii) 创建一个新模型
new_model = model.__class__.from_config(model_config, custom_objects={}) # change custom objects if necessary
Run Code Online (Sandbox Code Playgroud)
ii) 复制权重
# iterate over all the layers that we want to get weights from
weights = [layer.get_weights() for layer in model.layers[1:]]
for layer, weight in zip(new_model.layers[1:], weights):
layer.set_weights(weight)
Run Code Online (Sandbox Code Playgroud)
2.
你可以尝试像kerassurgeon这样的库(我正在链接到一个与 tensorflow keras 版本一起使用的叉子)。请注意,插入和删除操作仅在某些条件下有效,例如兼容维度。
from kerassurgeon.operations import delete_layer, insert_layer
model = delete_layer(model, layer_1)
# insert new_layer_1 before layer_2 in a model
model = insert_layer(model, layer_2, new_layer_3)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
10446 次 |
| 最近记录: |