diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 1761fc0135..9180ab489c 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -17,16 +17,19 @@ import csv import inspect import os from argparse import Namespace -from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union from warnings import warn import torch import yaml from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE -from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities import AttributeDict, rank_zero_warn, _OMEGACONF_AVAILABLE +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.parsing import parse_class_init_keys PRIMITIVE_TYPES = (bool, int, float, str) @@ -34,6 +37,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) if _OMEGACONF_AVAILABLE: from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig + from omegaconf.errors import UnsupportedValueType, ValidationError + # the older shall be on the top CHECKPOINT_PAST_HPARAMS_KEYS = ( @@ -321,9 +327,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]: """Load hparams from a file. + Args: + config_yaml: Path to config yaml file + use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True, + the hparams will be converted to `DictConfig` if possible + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') >>> path_yaml = './testing-hparams.yaml' >>> save_hparams_to_yaml(path_yaml, hparams) @@ -338,9 +349,15 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: return {} with fs.open(config_yaml, "r") as fp: - tags = yaml.load(fp, Loader=yaml.UnsafeLoader) + hparams = yaml.load(fp, Loader=yaml.UnsafeLoader) - return tags + if _OMEGACONF_AVAILABLE: + if use_omegaconf: + try: + return OmegaConf.create(hparams) + except (UnsupportedValueType, ValidationError): + pass + return hparams def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: @@ -361,15 +378,16 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving with OmegaConf objects if _OMEGACONF_AVAILABLE: - if OmegaConf.is_config(hparams): - with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(hparams, fp, resolve=True) - return - for v in hparams.values(): - if OmegaConf.is_config(v): - with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True) + # deepcopy: hparams from user shouldn't be resolved + hparams = deepcopy(hparams) + to_container = partial(OmegaConf.to_container, resolve=True) + hparams = apply_to_collection(hparams, DictConfig, to_container) + with fs.open(config_yaml, "w", encoding="utf-8") as fp: + try: + OmegaConf.save(hparams, fp) return + except (UnsupportedValueType, ValidationError): + pass if not isinstance(hparams, dict): raise TypeError("hparams must be dictionary") diff --git a/pytorch_lightning/utilities/package_utils.py b/pytorch_lightning/utilities/package_utils.py new file mode 100644 index 0000000000..99fd6fcc7e --- /dev/null +++ b/pytorch_lightning/utilities/package_utils.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib + + +def _module_available(module_path: str) -> bool: + """Testing if given module is avalaible in your env + + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + # todo: find a better way than try / except + try: + mods = module_path.split('.') + assert mods, 'nothing given to test' + # it has to be tested as per partets + for i in range(len(mods)): + module_path = '.'.join(mods[:i + 1]) + if importlib.util.find_spec(module_path) is None: + return False + return True + except AttributeError: + return False diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 521dd52005..5d90583345 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -18,6 +18,8 @@ from argparse import Namespace from typing import Dict, Tuple, Union from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.package_utils import _module_available def str_to_bool_or_str(val: str) -> Union[str, bool]: @@ -115,7 +117,6 @@ def get_init_args(frame) -> dict: self_var, args_var, kwargs_var = parse_class_init_keys(cls) filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n] exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') - # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} local_args.update(local_args.get(kwargs_var, {})) diff --git a/tests/models/conf/config.yaml b/tests/models/conf/config.yaml new file mode 100644 index 0000000000..faf751c24f --- /dev/null +++ b/tests/models/conf/config.yaml @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +defaults: + - training: default + +log: ${training.log} diff --git a/tests/models/conf/training/default.yaml b/tests/models/conf/training/default.yaml new file mode 100644 index 0000000000..2c35b22365 --- /dev/null +++ b/tests/models/conf/training/default.yaml @@ -0,0 +1,2 @@ +# @package training +log: "Something" diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7081d450ee..bd8ad9d116 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -25,10 +25,13 @@ from torch.nn import functional as F from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import AttributeDict, is_picklable +from pytorch_lightning.utilities import AttributeDict, is_picklable, _HYDRA_EXPERIMENTAL_AVAILABLE from tests.base import BoringModel, EvalModelTemplate, TrialMNIST +if _HYDRA_EXPERIMENTAL_AVAILABLE: + from hydra.experimental import compose, initialize class SaveHparamsModel(BoringModel): """ Tests that a model can take an object """ @@ -483,13 +486,13 @@ def test_hparams_save_yaml(tmpdir): path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml') save_hparams_to_yaml(path_yaml, hparams) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, Namespace(**hparams)) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) assert load_hparams_from_yaml(path_yaml) == hparams @@ -636,3 +639,46 @@ def test_model_with_fsspec_as_parameter(tmpdir): ) trainer.fit(model) trainer.test() + + +@pytest.mark.skipif(not HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available") +def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): + """ + This test relies on configuration saved under tests/models/conf/config.yaml + """ + + class TestHydraModel(BoringModel): + + def __init__(self, args_0, args_1, args_2, kwarg_1=None): + self.save_hyperparameters() + self.test_hparams() + config_file = f"{tmpdir}/hparams.yaml" + save_hparams_to_yaml(config_file, self.hparams) + self.hparams = load_hparams_from_yaml(config_file) + self.test_hparams() + super().__init__() + + def test_hparams(self): + assert self.hparams.args_0.log == "Something" + assert self.hparams.args_1['cfg'].log == "Something" + assert self.hparams.args_2[0].log == "Something" + assert self.hparams.kwarg_1['cfg'][0].log == "Something" + + with initialize(config_path="conf"): + args_0 = compose(config_name="config") + args_1 = {"cfg": compose(config_name="config")} + args_2 = [compose(config_name="config")] + kwarg_1 = {"cfg": [compose(config_name="config")]} + model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1) + epochs = 2 + checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + limit_train_batches=10, + limit_val_batches=10, + max_epochs=epochs, + logger=False, + ) + trainer.fit(model) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)