在哪里可以找到 torch.unique() 的源代码?

Bri*_*nto 3 pytorch

我只能在pytorch源代码(https://github.com/pytorch/pytorch/blob/2367face24afb159f73ebf40dc6f23e46132b770/torch/tical.py#L783)中找到以下函数调用:

_VF.unique_dim()torch._unique2()

但它们没有指向目录中的其他任何地方

jod*_*dag 6

大多数 pytorch 后端代码都是用 C++ 和/或 CUDA 实现的。要查看它,您需要在源代码中找到适当的入口点。有几种方法可以做到这一点,但我发现最简单的方法是在 github 上搜索关键字,而无需自己下载所有代码。

例如,如果您访问github.com并搜索unique_dim repo:pytorch/pytorch,然后单击左侧的“代码”选项卡,您应该很快找到以下内容。

来自torch/jit/_builtins.py:103

 17: _builtin_ops = [
...
103:    (torch._VF.unique_dim, "aten::unique_dim"),
Run Code Online (Sandbox Code Playgroud)

通过对代码的进一步分析,我们可以得出结论,torch._VF.unique_dim实际上是调用aten::unique_dimATen 库中的函数。

与ATen中的大多数函数一样,该函数有多种实现。大多数 ATen 函数都注册在aten/src/ATen/native/native_functions.yaml中,一般这里的函数都会有_cpu_cuda版本。

回到搜索结果我们可以发现CUDA实现实际上调用的是unique_dim_cudaaten /src/ATen/native/cuda/Unique.cu:197处的函数

196: std::tuple<Tensor, Tensor, Tensor>
197: unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
198:   return AT_DISPATCH_ALL_TYPES_AND2(kBool, kHalf, self.scalar_type(), "unique_dim", [&] {
199:     return unique_dim_cuda_template<scalar_t>(self, dim, false, return_inverse, return_counts);
200:   });
201: }
Run Code Online (Sandbox Code Playgroud)

CPU 实现正在调用unique_dim_cpuaten /src/ATen/native/Unique.cpp:271 处的函数

270: std::tuple<Tensor, Tensor, Tensor>
271: unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse, const bool return_counts) {
272:   return AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "unique_dim", [&] {
273:     // The current implementation using `dim` always sorts due to unhashable tensors
274:     return _unique_dim_cpu_template<scalar_t>(self, dim, false, return_inverse, return_counts);
275:   });
276: }
Run Code Online (Sandbox Code Playgroud)

从这一点开始,您应该能够进一步跟踪函数调用,以准确了解它们在做什么。

经过类似的搜索字符串,您应该发现分别针对 CUDA 和 CPU在aten/src/ATen/native/cuda/Unique.cu:188aten/src/ATen/native/Unique.cpp:264torch._unique2实现。