Dan*_*ous 5 python deep-learning keras tensorflow
I'm building a fusion model from ResNet50 and DenseNet121 in keras.applications, but an error is raised when the model is saved. If I use only one of the two networks (for example, DenseNet only), everything works.
Fusion of ResNet50 and DenseNet121:
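from keras.layers import Input, GlobalAveragePooling2D, Flatten, Dense, concatenate
from keras.models import Model
import keras.applications as app
# input_shape, base_weights and class_names are defined elsewhere in the original script
img_input = Input(shape=input_shape)
densenet = app.DenseNet121(
    include_top=False,
    input_tensor=img_input,
    input_shape=input_shape,
    weights=base_weights)
resnet = app.ResNet50(
    include_top=False,
    input_tensor=img_input,
    input_shape=input_shape,
    weights=base_weights)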
x1 = densenet.output
x1 = GlobalAveragePooling2D(name='dn_gap_last')(x1)
# then x1.shape is (batch, 1024)
x2 = resnet.output
x2 = Flatten()(x2) # then x2.shape is (batch, 2048)
x = concatenate([x1, x2], axis=-1)
predictions = Dense(len(class_names), activation="sigmoid", name="predictions")(x)
model = Model(inputs=img_input, outputs=predictions)
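To make the shape comments in the fusion code concrete, here is a NumPy-only sketch of the head that mirrors GlobalAveragePooling2D, Flatten, concatenate and a sigmoid Dense layer. It assumes a 224x224 input, so DenseNet121 (include_top=False) yields a 7x7x1024 map and the older Keras ResNet50 (include_top=False, with its built-in 7x7 average pool) yields 1x1x2048, which matches the 1024 and 2048 in the comments. The batch size of 4 and the 14 classes are hypothetical placeholders; the real len(class_names) is not given.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: batch of 4, 14 output classes (not from the original post)
batch, n_classes = 4, 14

# Assumed backbone outputs for a 224x224 input
dense_feat = rng.random((batch, 7, 7, 1024), dtype=np.float32)  # DenseNet121 features
res_feat = rng.random((batch, 1, 1, 2048), dtype=np.float32)    # older-Keras ResNet50 features

# GlobalAveragePooling2D: average over the two spatial axes
x1 = dense_feat.mean(axis=(1, 2))        # (batch, 1024)
# Flatten: collapse everything except the batch axis
x2 = res_feat.reshape(batch, -1)         # (batch, 2048)
# concatenate(axis=-1): join the two feature vectors
x = np.concatenate([x1, x2], axis=-1)    # (batch, 3072)

# Dense(n_classes, activation="sigmoid") with small random weights
W = (rng.standard_normal((x.shape[1], n_classes)) * 0.01).astype(np.float32)
b = np.zeros(n_classes, dtype=np.float32)
predictions = 1.0 / (1.0 + np.exp(-(x @ W + b)))  # (batch, n_classes)

print(x1.shape, x2.shape, x.shape, predictions.shape)
```

Each sigmoid output is an independent per-class probability, which suits the multi-label setup implied by activation="sigmoid".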
and I save the model with ModelCheckpoint:
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(
    output_weights_path,
    save_weights_only=True,
    save_best_only=True,
    verbose=1,
)
but an error is raised while saving the model:
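With save_best_only=True, the callback tracks the best monitored value (val_loss by default) and only saves when it improves. Because the best value starts at infinity, the first epoch always triggers a save, which is why the crash appears right after epoch 1. The class below is a simplified, hypothetical model of that decision logic, not the Keras source; the real callback calls model.save_weights(filepath) at that point (visible in the traceback) because save_weights_only=True.

```python
import math


class BestOnlySaver:
    """Hypothetical, simplified model of ModelCheckpoint(save_best_only=True)
    monitoring val_loss, where lower is better."""

    def __init__(self):
        # Initial best value; this is where "improved from inf" in the log comes from
        self.best = math.inf

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best:
            print(f"Epoch {epoch:05d}: val_loss improved from {self.best:.5f} "
                  f"to {val_loss:.5f}, saving model")
            self.best = val_loss
            return True  # the real callback would write the weights file here
        print(f"Epoch {epoch:05d}: val_loss did not improve from {self.best:.5f}")
        return False


saver = BestOnlySaver()
saved = [saver.on_epoch_end(e, loss)
         for e, loss in enumerate([0.72018, 0.75, 0.70], start=1)]
print(saved)  # [True, False, True]
```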
Epoch 00001: val_loss improved from inf to 0.72018, saving model to ./experiments/8/weights.h5
Traceback (most recent call last):
File "train.py", line 229, in <module>
main()
File "train.py", line 212, in main
shuffle=False,
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/training.py", line 2280, in fit_generator
callbacks.on_epoch_end(epoch, epoch_logs)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/callbacks.py", line 77, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/callbacks.py", line 445, in on_epoch_end
self.model.save_weights(filepath, overwrite=True)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/topology.py", line 2607, in save_weights
save_weights_to_hdf5_group(f, self.layers)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/keras/engine/topology.py", line 2878, in save_weights_to_hdf5_group
g = f.create_group(layer.name)
File "/home/hqt/chest-x-ray-project/code/venv/lib/python3.6/site-packages/h5py/_hl/group.py", line 50, in create_group
gid = h5g.create(self.id, name, lcpl=lcpl)
File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
File "h5py/h5g.pyx", line 151, in h5py.h5g.create
ValueError: Unable to create group (name already exists)
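The traceback shows the failure inside save_weights_to_hdf5_group, which calls f.create_group(layer.name) once per layer. One likely explanation (an assumption, not confirmed in the post): DenseNet121 names layers like conv1/conv, and h5py treats "/" as a path separator, implicitly creating a parent group conv1; the older Keras ResNet50 has a layer named plainly conv1, so if the DenseNet layer is written first, creating conv1 fails with "name already exists". DenseNet121 alone has no plain conv1 layer, which would explain why each network saves fine on its own. The hypothetical helper below replays the group creations in pure Python to flag such clashes.

```python
def find_h5_name_collisions(layer_names):
    """Return the layer names whose create_group call would fail.

    Assumption: one group is created per layer name, and '/' inside a name
    implicitly creates every parent group (h5py's default behaviour).
    """
    existing = set()   # every group path that exists so far, explicit or implicit
    collisions = []
    for name in layer_names:
        if name in existing:
            # Group already exists: an exact duplicate or an implicit parent
            collisions.append(name)
            continue
        parts = name.split("/")
        # Register the group itself and all of its implicitly created parents
        for i in range(1, len(parts) + 1):
            existing.add("/".join(parts[:i]))
    return collisions


# Illustrative subset: DenseNet121-style names first, then ResNet50-style names
names = ["input_1", "conv1/conv", "conv1/bn", "conv1", "bn_conv1"]
print(find_h5_name_collisions(names))  # ['conv1']
```

On the fused model this could be run as find_h5_name_collisions([layer.name for layer in model.layers]) before training. If it reports names, giving one backbone's layers distinct names (for example, a prefix) is one possible workaround, untested here.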
If I need to use an operation like tf.tile, as in your case, I call it through a Lambda layer. The working code is as follows:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model
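# Wrap the raw TensorFlow op in a function so it can run inside a Lambda layer
def my_fun(a):
    # Repeat the input batch_size times along the feature axis
    out = tf.tile(a, (1, tf.shape(a)[0]))
    return out
a = Input(shape=(10,))
# Direct call on the Keras tensor, replaced by the Lambda layer below:
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(my_fun)(a)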
model = Model(a, out)
x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())
model.save('my_model.h5')
# Load the model back from the HDF5 file
new_model=tf.keras.models.load_model("my_model.h5")
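What the Lambda layer computes can be checked without TensorFlow: tf.tile(a, (1, tf.shape(a)[0])) keeps axis 0 unchanged and repeats axis 1 batch_size times, so the output width depends on the batch size at call time. This NumPy sketch uses np.tile, which follows the same "multiples" convention, with the same (50, 10) input as the Keras example above.

```python
import numpy as np

# Same input as the Keras example: 50 rows, 10 features each
x = np.zeros((50, 10), dtype=np.float32)

# Keep axis 0 as-is, repeat axis 1 batch_size (= 50) times
out = np.tile(x, (1, x.shape[0]))
print(out.shape)  # (50, 500)

# A tiny case makes the repetition pattern visible
small = np.tile(np.array([[1, 2], [3, 4]]), (1, 2))
print(small)
# [[1 2 1 2]
#  [3 4 3 4]]
```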
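Anyone running into a similar problem, please follow the GitHub issue related to this problem for the final solution. Thanks!
With recent code changes, you can save the model in "h5" format with tf-nightly without any problem, as shown below.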
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras import Model
a = Input(shape=(10,))
out = tf.tile(a, (1, tf.shape(a)[0]))
model = Model(a, out)
x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())
model.save('./my_model', save_format='h5')