在张量流中我注册一个操作,如下所示:
REGISTER_OP("RimeBSqrt")
.Input("stokes: FT")
.Input("alpha: FT")
.Input("frequency: FT")
.Input("ref_freq: FT")
.Output("b_sqrt: CT")
.Attr("FT: {float, double} = DT_FLOAT")
.Attr("CT: {complex64, complex128} = DT_COMPLEX64");
Run Code Online (Sandbox Code Playgroud)
上述所有输入都是张量,但 ref_freq 是标量或 0 维张量。在 CPU 内核的 Compute() 方法中,我可以执行以下操作来提取标量:
const Tensor & in_ref_freq = context->input(3);
FT ref_freq = in_ref_freq.tensor<FT, 1>()(0);
Run Code Online (Sandbox Code Playgroud)
然而,相同类型的代码会在我的 GPU 内核的 Compute() 方法中生成段错误,因为 CPU 现在尝试访问 GPU 设备上的内存块。在将这个标量值发送到 GPU 之前是否有办法拦截它?我想避免 CUDA 内核中出现以下额外级别的内存间接寻址:
template <typename FT>
__global__ void kernel(..., FT * ref_freq, ...)
{
FT value = ref_freq[0];
}
Run Code Online (Sandbox Code Playgroud)
我不认为Attr这是使用的方法,ref_freq因为它是可变的、可配置的值。
您可以指定 TensorFlow 的一个或多个输入(或输出)OpKernel位于“主机内存”中,这样您就可以访问方法中的值Compute()。为此,您需要修改您的REGISTER_KERNEL_BUILDER()调用以添加.HostMemory("ref_freq")指令:
REGISTER_KERNEL_BUILDER(
Name("RimeBSqrt")
.Device(tensorflow::DEVICE_GPU)
.TypeConstraint<float>("FT")
.TypeConstraint<tensorflow::complex64>("CT")
.HostMemory("ref_freq"),
RimeBSqrt<tensorflow::GPUDevice, float, tensorflow::complex64>);
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
877 次 |
| 最近记录: |