我快要疯了。
我使用 tensorflow keras 定义了一个顺序模型:
model = tf.keras.Sequential([tf.keras.layer.Dense(128,input_shape(784,),activation="relu"),
tf.keras.layer.Dense(10,activation="softmax"])
model.compile(optimizer="adam",loss="mse")
keras.experimental.export_saved_model(model,"keras_model")
Run Code Online (Sandbox Code Playgroud)
我使用 c_api.h在C 程序中训练所述模型
C 程序将权重保存在检查点文件中。
尝试从检查点文件中恢复 python 中的权重时:
keras.experimental.load_from_saved_model("keras_model/")
#OR
model = tf.keras.Sequential([tf.keras.layer.Dense(128,input_shape(784,),activation="relu"),
tf.keras.layer.Dense(10,activation="softmax"])
model.load_weights("keras_model/variables/variables")
#OR
checkpoint = tf.train.Checkpoint(model=model)
status = checkpoint.restore("keras_model/variables/variables")
Run Code Online (Sandbox Code Playgroud)
我最终得到一个错误并且没有恢复权重。
我能够恢复重量并继续在我的 C 程序中训练
keras.experimental.load_from_saved_model("keras_model/")
WARNING: Logging before flag parsing goes to stderr.
W0918 15:18:04.350199 140418474760000 deprecation.py:323] From <ipython-input-2-06ea110fdc8e>:1: load_from_saved_model (from tensorflow.python.keras.saving.saved_model_experimental) is deprecated and will be removed in a future version.
Instructions for updating:
The experimental save and load functions have been deprecated. Please …Run Code Online (Sandbox Code Playgroud) 我正在探索 tensorflow 2.0 的 c API。
问题:当将模型加载到 python 时,权重没有恢复,因此模型似乎未经训练。
工作流程:我使用 TF 2.0 C api 来处理我的模型的训练。我遵循的一般设置是:
1.使用TF keras api在python中定义模型。
import tensorflow as tf
from tensorflow import keras
model = keras.Sequential([keras.layers.Dense(128,
input_shape=(784,),
activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss="categorical_crossentropy",
metrics=['accuracy'])
keras.experimental.export_saved_model(model,"keras_model")
Run Code Online (Sandbox Code Playgroud)
我正在使用 keras.experimental.export_saved_model() 因为我需要使用 keras.Model.save() 时未保存的“signature_def['train']”。
2. 使用 TF 2.0 C api 在 C 中训练模型 保存的模型然后通过以下方式加载到我的 C 程序中:
TF_LoadSessionFromSavedModel()
Run Code Online (Sandbox Code Playgroud)
随后对其进行训练并保存检查点:
TF_SessionRun()
Run Code Online (Sandbox Code Playgroud)
保存模型会在存储模型的“变量”文件夹中创建新的检查点文件(“checkpoint.index”和“checkpoint.data-00000-of-00001”)。
3. 问题在python中重新加载模型 训练后我在python中重新加载我的模型。这是我发现加载的模型具有与未经训练的模型相对应的 wights 的地方。我知道这一点,因为当我在 C 中训练的模型准确预测时,预测是胡言乱语。我通过以下方式加载我的模型:
import tensorflow as tf
from tensorflow import keras
model = keras.experimental.load_from_saved_model("keras_model")
Run Code Online (Sandbox Code Playgroud)
同样,我使用 keras.experimental.load_from_saved_model() 因为当我使用 …