在C++中将pytorch张量转换为opencv mat,反之亦然

bas*_*low 7 c++ memory opencv pytorch

我想在 C++ 中将 pytorch 张量转换为 opencv mat,反之亦然。我有这两个功能:

cv::Mat TensorToCVMat(torch::Tensor tensor)
{
    std::cout << "converting tensor to cvmat\n";
    tensor = tensor.squeeze().detach().permute({1, 2, 0});
    tensor = tensor.mul(255).clamp(0, 255).to(torch::kU8);
    tensor = tensor.to(torch::kCPU);
    int64_t height = tensor.size(0);
    int64_t width = tensor.size(1);
    cv::Mat mat(width, height, CV_8UC3);
    std::memcpy((void *)mat.data, tensor.data_ptr(), sizeof(torch::kU8) * tensor.numel());
    return mat.clone();
}

torch::Tensor CVMatToTensor(cv::Mat mat)
{
    std::cout << "converting cvmat to tensor\n";
    cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB);
    cv::Mat matFloat;
    mat.convertTo(matFloat, CV_32F, 1.0 / 255);
    auto size = matFloat.size();
    auto nChannels = matFloat.channels();
    auto tensor = torch::from_blob(matFloat.data, {1, size.height, size.width, nChannels});
    return tensor.permute({0, 3, 1, 2});
}
Run Code Online (Sandbox Code Playgroud)

在我的代码中,我加载了两个图像(image1image2),我想将它们转换为 pytorch 张量,然后返回到 opencv mat 以检查它是否有效。问题是我在第一次调用时遇到内存访问错误TensorToCVMat,我无法弄清楚出了什么问题,因为我对 C++ 编程没有太多经验。

cv::Mat image1;
image1 = cv::imread(argv[1]);
if (!image1.data)
{
    std::cout << "no image data\n";
    return -1;
}
cv::Mat image2;
image2 = cv::imread(argv[2]);
if (!image2.data)
{
    std::cout << "no image data\n";
    return -1;
}

torch::Tensor tensor1 = CVMatToTensor(image1);
cv::Mat new_image1 = TensorToCVMat(tensor1); // <<< this is where the memory access error is thrown
torch::Tensor tensor2 = CVMatToTensor(image2);
cv::Mat new_image2 = TensorToCVMat(tensor2);
Run Code Online (Sandbox Code Playgroud)

如果你能给我提示或解释来解决这个问题,那就太好了。

Sar*_*ran 4

不确定错误是否发生在 memcpy 步骤。但您可以使用Mat构造void* data函数的变体

Mat (int rows, int cols, int type, void *data, size_t step=AUTO_STEP)
Run Code Online (Sandbox Code Playgroud)

你可以跳过 memcpy 步骤

tensor = uint8_tensor //shape: (h, w, 3)
cv::Mat mat = cv::Mat(height, width, CV_8UC3, tensor.data_ptr());
return mat;
Run Code Online (Sandbox Code Playgroud)