Hydra:从代码中访问配置文件的名称

mic*_*cio 5 python fb-hydra omegaconf

我有一个配置树,例如:

config.yaml
model/
  model_a.yaml
  model_b.yaml
  model_c.yaml

Run Code Online (Sandbox Code Playgroud)

其中config.yaml包含:

config.yaml
model/
  model_a.yaml
  model_b.yaml
  model_c.yaml

Run Code Online (Sandbox Code Playgroud)

我想从我的 python 代码或文件本身访问所使用的模型配置文件的名称(默认或覆盖)。就像是:

# @package _global_
defaults:
  - _self_
  - model: model_a.yaml

some_var: 42
Run Code Online (Sandbox Code Playgroud)

或(来自例如model_a.yaml

@hydra.main(...)
def main(config):
  model_name = config.model.__filename__
Run Code Online (Sandbox Code Playgroud)

提前致谢!

Jas*_*sha 9

查看Hydra 文档的配置 Hydra - 简介页面中提到的Hydra.runtime.choices变量。该变量存储一个映射,该映射描述 Hydra 在构成输出配置时所做的每个选择。

model: model_a.yaml在默认列表中使用上面的示例:

# my_app.py
import hydra
from pprint import pprint
from hydra.core.hydra_config import HydraConfig
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(config):
    hydra_cfg = HydraConfig.get()
    print("choice of model:")
    pprint(OmegaConf.to_container(hydra_cfg.runtime.choices))

main()
Run Code Online (Sandbox Code Playgroud)

在命令行中:

$ python3 app.py
choices used:
{'hydra/callbacks': None,
 'hydra/env': 'default',
 'hydra/help': 'default',
 'hydra/hydra_help': 'default',
 'hydra/hydra_logging': 'default',
 'hydra/job_logging': 'default',
 'hydra/launcher': 'basic',
 'hydra/output': 'default',
 'hydra/sweeper': 'basic',
 'model': 'model_a.yaml'}
Run Code Online (Sandbox Code Playgroud)

正如您所看到的,在本例中,配置选项model_a.yaml存储在 Hydra 配置中,位置为hydra_cfg.runtime.choices.model