Mer*_*ury 6 jit pytorch torchscript libtorch
我目前正在编写一个 C++ 程序,需要对 torchScript 格式的 CNN 模型的结构进行一些分析。我按照 torch.org 上显示的方式使用 C++ torch 库,加载到模型中,如下所示:
#include <torch/script.h>
#include <torch/torch.h>
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
return 0;
}
Run Code Online (Sandbox Code Playgroud)
据我所知,module
由一组嵌套的集合组成,torch::jit::script::Module
其中最低的代表内置函数。我按如下方式访问这些最低模块:
void print_modules(const torch::jit::script::Module& imodule) {
for (const auto& module : imodule.named_children()) {
if(module.value.children().size() > 0){
print_modules(module.value);
}
else{
std::cout << module.name << "\n";
}
}
}
Run Code Online (Sandbox Code Playgroud)
该函数递归地遍历模块并打印最低级别的名称,这些名称对应于 torch 脚本的内置函数。
我现在的问题是,如何访问这些内置函数的详细信息,例如卷积的步幅长度?
我一生都无法弄清楚如何访问模块的这些基本属性。
归档时间: |
|
查看次数: |
923 次 |
最近记录: |