I am using Hydra to train machine learning models. It is great for running complex commands, such as python train.py data=MNIST batch_size=64 loss=l2. However, if I want to run the trained model with the same parameters, I have to do something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml. I then use argparse to load the previous yaml and the compose API to initialize the Hydra environment. The path to the trained model is inferred from the path of the .yaml file. If I want to modify one of the parameters, I have to add additional argparse arguments and run something like python reconstruct.py --config_file path_to_previous_job/.hydra/config.yaml --batch_size 128. The code then manually overrides the corresponding Hydra parameters with whatever was specified on the command line.
What is the right way of doing this?
My current code looks like this:
train.py:
import hydra

@hydra.main(config_name="config", config_path="conf")
def main(cfg):
    # [training code using cfg.data, cfg.batch_size, cfg.loss etc.]
    # [code outputs model checkpoint to job folder generated by Hydra]

main()
reconstruct.py:
import argparse
import os

from hydra.experimental import initialize, compose

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('hydra_config')
    parser.add_argument('--batch_size', type=int)
    # [other flags and parameters I may need to override]
    args = parser.parse_args()

    # Create the Hydra environment.
    initialize()
    cfg = compose(config_name=args.hydra_config)

    # Since checkpoints are stored next to the .hydra, we manually generate the path.
    checkpoint_dir = os.path.dirname(os.path.dirname(args.hydra_config))

    # Manually override any parameters which can be changed on the command line.
    batch_size = args.batch_size if args.batch_size else cfg.data.batch_size

    # [code which uses checkpoint_dir to load the model]
    # [code which uses both batch_size and params in cfg to set up the data etc.]
This is my first time posting here, so let me know if I need to clarify anything.
If you want to load the previous config as-is, without changing it, use OmegaConf.load(file_path).
If you want to re-compose the config (which sounds like what you are doing, since you mention adding things you want to override), I recommend using the Compose API and passing in the overrides from the overrides file in the job output directory (next to the stored config.yaml), concatenated with the current run's overrides.
This script seems to be doing the job:
import os
from dataclasses import dataclass
from os.path import join
from typing import Optional

from omegaconf import OmegaConf

import hydra
from hydra import compose
from hydra.core.config_store import ConfigStore
from hydra.core.hydra_config import HydraConfig
from hydra.utils import to_absolute_path


# You can also use a yaml config file instead of this Structured Config
@dataclass
class Config:
    load_checkpoint: Optional[str] = None
    batch_size: int = 16
    loss: str = "l2"


cs = ConfigStore.instance()
cs.store(name="config", node=Config)


@hydra.main(config_path=".", config_name="config")
def my_app(cfg: Config) -> None:
    if cfg.load_checkpoint is not None:
        output_dir = to_absolute_path(cfg.load_checkpoint)
        original_overrides = OmegaConf.load(join(output_dir, ".hydra/overrides.yaml"))
        current_overrides = HydraConfig.get().overrides.task

        hydra_config = OmegaConf.load(join(output_dir, ".hydra/hydra.yaml"))
        # getting the config name from the previous job.
        config_name = hydra_config.hydra.job.config_name
        # concatenating the original overrides with the current overrides
        overrides = original_overrides + current_overrides
        # compose a new config from scratch
        cfg = compose(config_name, overrides=overrides)

    # train
    print("Running in ", os.getcwd())
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    my_app()
~/tmp$ python train.py
Running in /home/omry/tmp/outputs/2021-04-19/21-23-13
load_checkpoint: null
batch_size: 16
loss: l2
~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13
Running in /home/omry/tmp/outputs/2021-04-19/21-23-22
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 16
loss: l2
~/tmp$ python train.py load_checkpoint=/home/omry/tmp/outputs/2021-04-19/21-23-13 batch_size=32
Running in /home/omry/tmp/outputs/2021-04-19/21-23-28
load_checkpoint: /home/omry/tmp/outputs/2021-04-19/21-23-13
batch_size: 32
loss: l2