我在以下代码中创建了一个 tf-agent DqnAgent:
tf_agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
train_step_counter=train_step_counter
Run Code Online (Sandbox Code Playgroud)
)
在训练循环期间,我保存了这个模型
tf.saved_model.save(tf_agent, saved_models_path)
Run Code Online (Sandbox Code Playgroud)
训练后,我想加载保存的模型
if tf.saved_model.contains_saved_model(saved_models_path):
tf_agent = tf.saved_model.load(saved_models_path)
Run Code Online (Sandbox Code Playgroud)
此代码仅在文件夹中saved_path包含一个时才会加载保存的模型,函数contains_saved_model(saved_models_path)返回True,因此模型已加载,但出现异常并且程序崩溃:
Traceback (most recent call last):
File "/home/claudino/Projetos/dino-tf-agents/dino_ia/model/agent.py", line 50, in <module>
tf_agent = tf.saved_model.load(saved_models_path)
File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 408, in load
return load_internal(export_dir, tags)
File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 432, in load_internal
export_dir)
File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 58, in __init__
self._load_all()
File "/home/claudino/Projetos/dino-tf-agents/venv/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 168, in _load_all
slot_variable = optimizer_object.add_slot(
AttributeError: '_UserObject' object has no attribute 'add_slot'
Process finished with exit code 1
Run Code Online (Sandbox Code Playgroud)
我浏览了 tensorflow 代码,但找不到问题。任何人都可以帮助我吗?
我使用tf-agents-nightly是因为 google 的 colaboratory 源代码在tf-agents“稳定”版本上不起作用(我不确定 tf-agents 是否真的稳定),并尝试了tensorflow1.3 和的代码,2.0.0-beta0发生了同样的问题。
小智 0
您尝试过 TensorFlow 2.7 吗?这通常有助于解决这个问题。
对我有用的其他方法是以这种方式加载模型(假设模型是keras/tf.keras模型):
try:
model = tf.keras.models.load_model(model_dir)
except:
load_options = tf.saved_model.LoadOptions(experimental_io_device= '/job:localhost')
model = tf.saved_model.load(model_dir, options= load_options)
Run Code Online (Sandbox Code Playgroud)
try 子句将导致异常,因为load_model()需要一个keras_metadata.pb文件,而当您使用 .txt 文件保存模型时,该文件不存在saved_model.save()。
但是,运行该子句将以某种方式使tf.saved_model.load()运行没有任何问题。后台可能发生某种我不太理解的交互,但它对我有用,并且no attribute add_slot不会出现“ ”错误。
| 归档时间: |
|
| 查看次数: |
1153 次 |
| 最近记录: |