Pur*_*nth · 11 · Tags: python, serialization, pickle, deep-learning, pytorch
I am trying to run a file named `api.py`. In this file, I load a pickle file containing a deep-learning model that was built and trained with PyTorch.
The functions below are the most important ones in `api.py`.
```python
import os

import cherrypy
import torch

import model  # project module that defines AutoEncoder (import path assumed)


def load_model_weights(model_architecture, weights_path):
    if os.path.isfile(weights_path):
        cherrypy.log("CHERRYPYLOG Loading model from: {}".format(weights_path))
        # Unpickle the saved state_dict and copy it into the freshly built model
        model_architecture.load_state_dict(torch.load(weights_path))
    else:
        raise ValueError("Path not found {}".format(weights_path))


def load_recommender(vector_dim, hidden, activation, dropout, weights_path):
    # Rebuild the architecture; `hidden` is a comma-separated string of layer sizes
    rencoder_api = model.AutoEncoder(layer_sizes=[vector_dim] + [int(l) for l in hidden.split(',')],
                                     nl_type=activation,
                                     is_constrained=False,
                                     dp_drop_prob=dropout,
                                     last_layer_activations=False)
    load_model_weights(rencoder_api, weights_path)
    rencoder_api.eval()
    rencoder_api = rencoder_api.cuda()
    return rencoder_api
```

Directory structure: [screenshot not preserved]
I get the following error, raised from `torch/serialization.py`. Can anyone help me fix it?
```
D:\Anaconda\envs\practise\lib\site-packages\torch\serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    762             "functionality.")
    763
--> 764     magic_number = pickle_module.load(f, **pickle_load_args)
    765     if magic_number != MAGIC_NUMBER:
    766         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: A load persistent id instruction was encountered,
but no persistent_load function was specified.
```
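The traceback shows the failure happening while `torch.load` reads the very first pickle in the file (the magic number), so it can help to check what the file on disk actually is before looking at the model code. The sketch below is a hypothetical helper (`describe_checkpoint` is not a PyTorch API) that uses only the standard library and classifies the file by its leading bytes: a ZIP archive (the container `torch.save` writes by default since PyTorch 1.6), a Git LFS pointer (a small text file left behind when the real weights were never downloaded), or a raw pickle stream. This is a heuristic under those assumptions, not a diagnosis of the question's file.

```python
import zipfile

# Git LFS pointer files are plain text starting with this prefix.
LFS_PREFIX = b"version https://git-lfs"


def describe_checkpoint(path: str) -> str:
    """Classify a checkpoint file by its leading bytes (heuristic only)."""
    with open(path, "rb") as f:
        head = f.read(64)
    if not head:
        return "empty file"
    if head.startswith(LFS_PREFIX):
        return "git-lfs pointer (real weights not downloaded)"
    if zipfile.is_zipfile(path):
        return "zip archive (torch.save format since PyTorch 1.6)"
    # Pickle protocol 2 and later start with the PROTO opcode 0x80.
    if head[:1] == b"\x80":
        return "raw pickle stream (legacy torch.save or plain pickle)"
    return "unknown format"
```

Calling `describe_checkpoint(weights_path)` before `torch.load` gives a quick hint about whether the file itself is the problem.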
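After searching the PyTorch documentation, I ended up saving the model in ONNX format, then loading that ONNX model back into a PyTorch model with onnx2pytorch and running inference with it.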
```python
import os

import cherrypy
import onnx
from onnx2pytorch import ConvertModel

import model  # project module that defines AutoEncoder (import path assumed)


def load_model_weights(model_architecture, weights_path):
    # weights_path is expected to point at the exported ONNX file (e.g. model.onnx)
    if os.path.isfile(weights_path):
        cherrypy.log("CHERRYPYLOG Loading model from: {}".format(weights_path))
        onnx_model = onnx.load(weights_path)
        # Rebuild a torch.nn.Module from the ONNX graph instead of unpickling a state_dict
        return ConvertModel(onnx_model)
    raise ValueError("Path not found {}".format(weights_path))


def load_recommender(vector_dim, hidden, activation, dropout, weights_path):
    rencoder_api = model.AutoEncoder(layer_sizes=[vector_dim] + [int(l) for l in hidden.split(',')],
                                     nl_type=activation,
                                     is_constrained=False,
                                     dp_drop_prob=dropout,
                                     last_layer_activations=False)
    # The model converted from ONNX replaces the freshly built AutoEncoder
    rencoder_api = load_model_weights(rencoder_api, weights_path)
    rencoder_api.eval()
    rencoder_api = rencoder_api.cuda()
    return rencoder_api
```
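For comparison, the `torch.load` route in the question's `load_model_weights` only works when the file was written by `torch.save` from a `state_dict` of the same architecture. Below is a minimal sketch of that matching save/load pair, assuming a small `nn.Sequential` as a hypothetical stand-in for `model.AutoEncoder` and a hypothetical file name `weights.pt`. Passing `map_location="cpu"` remaps tensors that were saved on a GPU onto the CPU, and the model can still be moved with `.cuda()` afterwards.

```python
import os
import tempfile

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for model.AutoEncoder; any nn.Module behaves the same way.
    return nn.Sequential(nn.Linear(8, 4), nn.SELU(), nn.Linear(4, 8))


def save_weights(module: nn.Module, path: str) -> None:
    # Save only the parameters (state_dict), not the pickled module object.
    torch.save(module.state_dict(), path)


def load_weights(module: nn.Module, path: str) -> nn.Module:
    if not os.path.isfile(path):
        raise ValueError("Path not found {}".format(path))
    # map_location="cpu" loads GPU-saved tensors onto the CPU.
    state_dict = torch.load(path, map_location="cpu")
    module.load_state_dict(state_dict)
    module.eval()
    return module


trained = build_model()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")  # hypothetical file name
    save_weights(trained, path)
    restored = load_weights(build_model(), path)

# The restored copy should produce the same outputs as the original.
x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
```

Keeping the save and load sides symmetric (`state_dict` in, `load_state_dict` out) avoids pickling the whole module class, which ties the file to the exact code layout it was saved from.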
Some useful resources:
Viewed 28462 times.