zha*_*itc 5 c++ templates pytorch
我正在尝试使用 C++ Tensor API 为 PyTorch 编写 C++/CUDA 扩展,并且我希望我的代码能够使用 float32 和 float16 (半精度)。我不确定如何访问来自 Python 的半张量的数据指针。
以下是我对浮点张量的处理方法:
// Access data pointer for float Tensor A
torch::Tensor A;
float* ptr = A.data<float>();
Run Code Online (Sandbox Code Playgroud)
这是我对半张量所做的尝试:
// CUDA float 16 type
// undefined symbol: _ZNK2at6Tensor4dataI6__halfEEPT_v
A.data<__half>();
// PyTorch float16 type
// error: no instance of function template "at::Tensor::data"
A.data<torch::ScalarType::Half>();
// Casting to __half*
// This compiles but throws and error if the requested pointer type doesn't match the Tensor type:
// RuntimeError: expected scalar type Float but found Half
(__half*)(A.data<float>());
Run Code Online (Sandbox Code Playgroud)
我尝试查看 C++ api 源代码,但找不到任何其他看起来像 float16 类型的内容。
系统信息:Python 3.6.2 PyTorch 1.0.1
归档时间: |
|
查看次数: |
1718 次 |
最近记录: |