libtorch (PyTorch C++) 奇怪的类语法

Jac*_*own 5 c++ pytorch libtorch

在官方PyTorch C ++在GitHub上的例子在这里 你可以看到一个类的奇怪定义:

class CustomDataset : public torch::data::datasets::Dataset<CustomDataset> {...}
Run Code Online (Sandbox Code Playgroud)

我的理解是,这定义了一个CustomDataset“继承自”或“扩展”的类torch::data::datasets::Dataset<CustomDataset>。这对我来说很奇怪,因为我们正在创建的类是从另一个类继承的,该类由我们正在创建的类进行参数化......这甚至是如何工作的?这是什么意思?在我看来,这就像一个Integer继承自的类vector<Integer>,这似乎很荒谬。

ζ--*_*ζ-- 8

这是奇怪重复出现的模板模式,简称 CRTP。这种技术的一个主要优点是它启用了所谓的静态多态性,这意味着 中的函数torch::data::datasets::Dataset可以调用 的函数CustomDataset,而无需使这些函数成为虚拟的(从而处理虚拟方法调度的运行时混乱等)。您还可以enable_if根据自定义数据集类型的属性执行编译时元编程,例如 compile-time s。

在 PyTorch 的情况下,BaseDataset( 的超类Dataset)大量使用这种技术来支持映射和过滤等操作:

  template <typename TransformType>
  MapDataset<Self, TransformType> map(TransformType transform) & {
    return datasets::map(static_cast<Self&>(*this), std::move(transform));
  }
Run Code Online (Sandbox Code Playgroud)

注意对this派生类型的静态转换(只要正确应用了 CRTP 就是合法的);datasets::map构造一个MapDataset对象,该对象也由数据集类型参数化,允许MapDataset实现静态调用方法get_batch(或遇到编译时错误,如果它们不存在)。

此外,由于MapDataset接收自定义数据集类型作为类型参数,编译时元编程是可能的:

  /// The implementation of `get_batch()` for the stateless case, which simply
  /// applies the transform to the output of `get_batch()` from the dataset.
  template <
      typename D = SourceDataset,
      typename = torch::disable_if_t<D::is_stateful>>
  OutputBatchType get_batch_impl(BatchRequestType indices) {
    return transform_.apply_batch(dataset_.get_batch(std::move(indices)));
  }

  /// The implementation of `get_batch()` for the stateful case. Here, we follow
  /// the semantics of `Optional.map()` in many functional languages, which
  /// applies a transformation to the optional's content when the optional
  /// contains a value, and returns a new optional (of a different type)  if the
  /// original optional returned by `get_batch()` was empty.
  template <typename D = SourceDataset>
  torch::enable_if_t<D::is_stateful, OutputBatchType> get_batch_impl(
      BatchRequestType indices) {
    if (auto batch = dataset_.get_batch(std::move(indices))) {
      return transform_.apply_batch(std::move(*batch));
    }
    return nullopt;
  }
Run Code Online (Sandbox Code Playgroud)

请注意,条件启用取决于SourceDataset,我们之所以可用,是因为数据集使用此 CRTP 模式进行了参数化。

  • 太感谢了!我一整天都在盯着这个问题,想知道这是不是一个愚蠢的问题。对于来自 Python 的我来说,C++ 太令人困惑了。我希望能够编写快速的代码,但我也希望能够看到代码中的美丽。我想我对 C++ 还没有完全掌握。 (2认同)
  • 即使我不完全理解你所说的一切,我现在至少有一些想法并且有一个可以谷歌的关键字。有时这是最困难的部分... (2认同)