bugfix: Resolve interpolation bug with Hydra ()

* resolve bug

* Apply suggestions from code review

* resolve package import

* resolve import

* update on comments

* update on comments

* hacky fix

* update

* exit

* update

* to_container

* typo

* resolve import

* update

* resolve pep8

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

(cherry picked from commit bb5031b3bf)
Authored by chaton on 2021-01-09 13:55:55 +00:00, committed by Jirka Borovec
parent 127e04124d
commit 8e75f2cde0
6 changed files with 139 additions and 19 deletions (under pytorch_lightning and tests/models)


@@ -17,16 +17,19 @@ import csv
import inspect
import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
from warnings import warn
import torch
import yaml
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities import AttributeDict, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.parsing import parse_class_init_keys
PRIMITIVE_TYPES = (bool, int, float, str)
@@ -34,6 +37,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.errors import UnsupportedValueType, ValidationError
# keys from older versions shall be listed on top
CHECKPOINT_PAST_HPARAMS_KEYS = (
@@ -321,9 +327,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
writer.writerow({"key": k, "value": v})
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
"""Load hparams from a file.
Args:
config_yaml: Path to config yaml file
use_omegaconf: If both `_OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
the hparams will be converted to a `DictConfig` if possible.
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
@@ -338,9 +349,15 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
return {}
with fs.open(config_yaml, "r") as fp:
tags = yaml.load(fp, Loader=yaml.UnsafeLoader)
hparams = yaml.load(fp, Loader=yaml.UnsafeLoader)
return tags
if _OMEGACONF_AVAILABLE:
if use_omegaconf:
try:
return OmegaConf.create(hparams)
except (UnsupportedValueType, ValidationError):
pass
return hparams
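For reference, a minimal sketch of the new loading behaviour (the path is illustrative and omegaconf is assumed to be installed):

    import os
    import tempfile
    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

    path_yaml = os.path.join(tempfile.mkdtemp(), "hparams.yaml")
    save_hparams_to_yaml(path_yaml, {"batch_size": 32, "learning_rate": 0.001})

    # default: the hparams come back as an omegaconf DictConfig when omegaconf is installed
    hparams = load_hparams_from_yaml(path_yaml)

    # opt out: always get the plain dict produced by yaml.load
    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == {"batch_size": 32, "learning_rate": 0.001}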
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
@@ -361,15 +378,16 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
# saving with OmegaConf objects
if _OMEGACONF_AVAILABLE:
if OmegaConf.is_config(hparams):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(hparams, fp, resolve=True)
return
for v in hparams.values():
if OmegaConf.is_config(v):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
# deepcopy: hparams from user shouldn't be resolved
hparams = deepcopy(hparams)
to_container = partial(OmegaConf.to_container, resolve=True)
hparams = apply_to_collection(hparams, DictConfig, to_container)
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
try:
OmegaConf.save(hparams, fp)
return
except (UnsupportedValueType, ValidationError):
pass
if not isinstance(hparams, dict):
raise TypeError("hparams must be dictionary")
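A minimal sketch of what the deepcopy plus `to_container(resolve=True)` combination is after: the YAML on disk gets the resolved values, while the caller's config keeps its interpolation (path illustrative, omegaconf assumed installed):

    import os
    import tempfile
    from omegaconf import OmegaConf
    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

    cfg = OmegaConf.create({"training": {"log": "Something"}, "log": "${training.log}"})
    path_yaml = os.path.join(tempfile.mkdtemp(), "hparams.yaml")
    save_hparams_to_yaml(path_yaml, {"cfg": cfg})

    # the caller's config still carries the unresolved interpolation
    assert OmegaConf.to_container(cfg, resolve=False)["log"] == "${training.log}"

    # the file on disk was written with the interpolation resolved
    assert load_hparams_from_yaml(path_yaml).cfg.log == "Something"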


@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
def _module_available(module_path: str) -> bool:
"""Testing if given module is avalaible in your env
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
# todo: find a better way than try / except
try:
mods = module_path.split('.')
assert mods, 'nothing given to test'
# the module path has to be tested piece by piece (per partes)
for i in range(len(mods)):
module_path = '.'.join(mods[:i + 1])
if importlib.util.find_spec(module_path) is None:
return False
return True
except AttributeError:
return False
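The availability flags used elsewhere in this change are presumably built on this helper; the definition site is not part of the hunks shown here, so the following is only an assumed sketch:

    # assumed usage, mirroring the `_HYDRA_EXPERIMENTAL_AVAILABLE` import in the test file below
    _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")

    if _HYDRA_EXPERIMENTAL_AVAILABLE:
        from hydra.experimental import compose, initialize  # noqa: F401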


@@ -18,6 +18,8 @@ from argparse import Namespace
from typing import Dict, Tuple, Union
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.package_utils import _module_available
def str_to_bool_or_str(val: str) -> Union[str, bool]:
@@ -115,7 +117,6 @@ def get_init_args(frame) -> dict:
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_var, {}))


@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
defaults:
- training: default
log: ${training.log}


@@ -0,0 +1,2 @@
# @package training
log: "Something"


@@ -25,10 +25,13 @@ from torch.nn import functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
from pytorch_lightning.utilities import AttributeDict, is_picklable, _HYDRA_EXPERIMENTAL_AVAILABLE
from tests.base import BoringModel, EvalModelTemplate, TrialMNIST
if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize
class SaveHparamsModel(BoringModel):
""" Tests that a model can take an object """
@@ -483,13 +486,13 @@ def test_hparams_save_yaml(tmpdir):
path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
save_hparams_to_yaml(path_yaml, hparams)
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
@@ -636,3 +639,46 @@ def test_model_with_fsspec_as_parameter(tmpdir):
)
trainer.fit(model)
trainer.test()
@pytest.mark.skipif(not _HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available")
def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
"""
This test relies on configuration saved under tests/models/conf/config.yaml
"""
class TestHydraModel(BoringModel):
def __init__(self, args_0, args_1, args_2, kwarg_1=None):
self.save_hyperparameters()
self.test_hparams()
config_file = f"{tmpdir}/hparams.yaml"
save_hparams_to_yaml(config_file, self.hparams)
self.hparams = load_hparams_from_yaml(config_file)
self.test_hparams()
super().__init__()
def test_hparams(self):
assert self.hparams.args_0.log == "Something"
assert self.hparams.args_1['cfg'].log == "Something"
assert self.hparams.args_2[0].log == "Something"
assert self.hparams.kwarg_1['cfg'][0].log == "Something"
with initialize(config_path="conf"):
args_0 = compose(config_name="config")
args_1 = {"cfg": compose(config_name="config")}
args_2 = [compose(config_name="config")]
kwarg_1 = {"cfg": [compose(config_name="config")]}
model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
epochs = 2
checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
limit_train_batches=10,
limit_val_batches=10,
max_epochs=epochs,
logger=False,
)
trainer.fit(model)
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)