I am trying to save the trained model below.
resnet = ResNet50V2(input_shape=(im_size,im_size,3), weights='imagenet', include_top=False)
headModel = AvgPool2D(pool_size=(3,3))(resnet.output)
headModel = Flatten(name="flatten")(headModel)
headModel = Dense(256, activation="relu")(headModel)
headModel = Dropout(0.5)(headModel)
headModel = Dense(1, activation="sigmoid")(headModel)
resnet50v2 = Model(inputs=resnet.input, outputs=headModel)
resnet50v2.compile(loss='binary_crossentropy', optimizer=opt, metrics=METRICS)
history = resnet50v2.fit(
    datagen.flow(X_train, y_train, batch_size=32, subset='training'),
    batch_size=batch_size,
    epochs=150,
    steps_per_epoch=steps_per_epoch,
    validation_data=datagen.flow(X_train, y_train, batch_size=8, subset='validation'))
However, whenever I try to save it with the following command:
resnet50v2.save('Saved_Models/resnet50.h5', save_format='h5')
I get the following error:
ValueError Traceback (most recent call last)
/tmp/ipykernel_3252071/2034094124.py in <module>
----> 1 resnet50v2.save('Saved_Models/resnet50.h5', save_format='h5')
~/.local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
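---> 67 raise …

I have implemented the following version of ResNet50. I trained the model on my own data in another notebook, so here I only need to load the weights and compile the model. Now I just want to make predictions on new data the model has not seen.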
def resnet50F(im_size):
    resnet = ResNet50(input_shape=(im_size, im_size, 3), weights='imagenet', include_top=False)
    headModel = AvgPool2D(pool_size=(3,3))(resnet.output)
    headModel = Flatten(name='flatten')(headModel)
    headModel = Dense(256, activation='relu')(headModel)
    headModel = Dropout(0.5)(headModel)
    headModel = Dense(1, activation='sigmoid')(headModel)
    model = Model(inputs=resnet.input, outputs=headModel)
    model.trainable = True
    return model
resnet50 = resnet50F(im_size=224)
resnet50.load_weights(PATH_MODEL_WEIGHTS)
opt = optimizers.Adam(learning_rate=1e-6)
resnet50.compile(loss='binary_crossentropy', optimizer=opt, metrics=METRICS)
predictions = resnet50.predict(X)
However, when I print predictions, I get the following output:
[[4.22752373e-06]
[2.81104029e-10]
[3.21204737e-02]
[5.09007333e-12]
[6.25871266e-08]
[3.95518853e-08]
[3.76289577e-09]
[1.04685043e-07]
[4.40788448e-01]
[4.18029167e-09]
[1.68976447e-04]
[4.83552366e-03]
[5.67837298e-01]
[1.92822833e-02]
[1.86168763e-04]
[3.30054699e-11]
[1.55285016e-01]
[1.40850764e-12]
[4.75460291e-02]
[2.36899691e-08]
[1.91837142e-04]
[2.70789745e-03]
[2.28864295e-07]
[1.04725331e-08]
[3.17185315e-15] …
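Because the last layer is Dense(1, activation='sigmoid'), each row of predictions is a single probability between 0 and 1, usually read as P(label == 1), not a class label. A minimal sketch of turning those probabilities into 0/1 labels is shown below. It assumes a 0.5 cutoff and that y_train was encoded as 0/1; both are assumptions, not something stated in the question. The array is a handful of rows picked from the output above, and the names threshold and labels are illustrative only.

```python
import numpy as np

# A few rows picked from the printed `predictions` array (shape (n, 1)).
predictions = np.array([
    [4.22752373e-06],
    [2.81104029e-10],
    [3.21204737e-02],
    [4.40788448e-01],
    [5.67837298e-01],
])

# Assumed cutoff: 0.5 is the common default for a sigmoid output; it may need
# tuning on validation data (e.g. if the classes are imbalanced).
threshold = 0.5

# Drop the trailing axis and map each probability to 0 or 1.
labels = (predictions[:, 0] > threshold).astype(int)

print(labels)  # [0 0 0 0 1]
```

With this cutoff, most of the printed values (e.g. 4.22752373e-06) map to class 0, and only rows above 0.5 (e.g. 5.67837298e-01) map to class 1.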