diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 28501fdcda..dea3fa99dd 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -16,11 +16,9 @@ from pytorch_lightning.utilities.cloud_io import gfile, cloud_open PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) try: - from omegaconf import Container + from omegaconf import OmegaConf except ImportError: - OMEGACONF_AVAILABLE = False -else: - OMEGACONF_AVAILABLE = True + OmegaConf = None # the older shall be on the top CHECKPOINT_PAST_HPARAMS_KEYS = ( @@ -330,23 +328,27 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: if not gfile.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") - if OMEGACONF_AVAILABLE and isinstance(hparams, Container): - from omegaconf import OmegaConf - - OmegaConf.save(hparams, config_yaml, resolve=True) - return - - # saving the standard way + # convert Namespace or AD to dict if isinstance(hparams, Namespace): hparams = vars(hparams) elif isinstance(hparams, AttributeDict): hparams = dict(hparams) + + # saving with OmegaConf objects + if OmegaConf is not None: + if OmegaConf.is_config(hparams): + OmegaConf.save(hparams, config_yaml, resolve=True) + return + for v in hparams.values(): + if OmegaConf.is_config(v): + OmegaConf.save(OmegaConf.create(hparams), config_yaml, resolve=True) + return + + # saving the standard way assert isinstance(hparams, dict) - - with cloud_open(config_yaml, "w", newline="") as fp: + with open(config_yaml, 'w', newline='') as fp: yaml.dump(hparams, fp) - def convert(val: str) -> Union[int, float, bool, str]: try: return ast.literal_eval(val)