如何在 Hydra 中使用 OmegaConf 自定义插值

ant*_*ec2 2 fb-hydra

如何在Hydra 中使用 OmegaConf自定义插值

一些背景:可以为平方根定义自定义插值:

from omegaconf import OmegaConf
import math
OmegaConf.register_resolver("sqrt", lambda x: math.sqrt(float(x)))
Run Code Online (Sandbox Code Playgroud)

并将其与此 config.yaml 一起使用:

foo: ${sqrt:9}
Run Code Online (Sandbox Code Playgroud)

加载和打印 foo:

cfg = OmegaConf.load('config.yaml')
print(cfg.foo)
Run Code Online (Sandbox Code Playgroud)

输出 3.0

使用 Hydra 尝试此操作时:

import hydra

@hydra.main(config_path="config.yaml")
def main(cfg):
  print(cfg.foo)

if __name__ == "__main__":
  main()
Run Code Online (Sandbox Code Playgroud)

我收到以下错误:

Unsupported interpolation type sqrt
    full_key: foo
    reference_type=Optional[Dict[Any, Any]]
    object_type=dict
Run Code Online (Sandbox Code Playgroud)

使用 Hydra 时如何注册我的解析器?

Omr*_*dan 10

您可以提前注册您的自定义解析器:

配置文件:

foo: ${sqrt:9}
Run Code Online (Sandbox Code Playgroud)

主要.py:

from omegaconf import OmegaConf
import math
import hydra

OmegaConf.register_new_resolver("sqrt", lambda x: math.sqrt(float(x)))

@hydra.main(config_path=".", config_name="config")
def main(cfg):
  print(cfg.foo)

if __name__ == "__main__":
  main()
Run Code Online (Sandbox Code Playgroud)

这将打印 3.0。

这种方法也适用于Compose API。当您访问节点(懒惰)时,正在对自定义解析器进行评估。您只需要在访问之前注册解析器。