bugfix: Resolve interpolation bug with Hydra (#5406)
* resolve bug
* Apply suggestions from code review
* resolve package import
* resolve import
* update on comments
* update on comments
* hacky fix
* update
* exit
* update
* to_container
* typo
* resolve import
* update
* resolve pep8
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
(cherry picked from commit bb5031b3bf
)
This commit is contained in:
parent
127e04124d
commit
8e75f2cde0
pytorch_lightning
tests/models
|
@ -17,16 +17,19 @@ import csv
|
|||
import inspect
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities import AttributeDict, rank_zero_warn, _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.parsing import parse_class_init_keys
|
||||
|
||||
PRIMITIVE_TYPES = (bool, int, float, str)
|
||||
|
@ -34,6 +37,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
|
|||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from omegaconf.errors import UnsupportedValueType, ValidationError
|
||||
|
||||
|
||||
# the older shall be on the top
|
||||
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
||||
|
@ -321,9 +327,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
|
|||
writer.writerow({"key": k, "value": v})
|
||||
|
||||
|
||||
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
|
||||
def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
|
||||
"""Load hparams from a file.
|
||||
|
||||
Args:
|
||||
config_yaml: Path to config yaml file
|
||||
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
|
||||
the hparams will be converted to `DictConfig` if possible
|
||||
|
||||
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
|
||||
>>> path_yaml = './testing-hparams.yaml'
|
||||
>>> save_hparams_to_yaml(path_yaml, hparams)
|
||||
|
@ -338,9 +349,15 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
|
|||
return {}
|
||||
|
||||
with fs.open(config_yaml, "r") as fp:
|
||||
tags = yaml.load(fp, Loader=yaml.UnsafeLoader)
|
||||
hparams = yaml.load(fp, Loader=yaml.UnsafeLoader)
|
||||
|
||||
return tags
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
if use_omegaconf:
|
||||
try:
|
||||
return OmegaConf.create(hparams)
|
||||
except (UnsupportedValueType, ValidationError):
|
||||
pass
|
||||
return hparams
|
||||
|
||||
|
||||
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
||||
|
@ -361,15 +378,16 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
|||
|
||||
# saving with OmegaConf objects
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
if OmegaConf.is_config(hparams):
|
||||
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
|
||||
OmegaConf.save(hparams, fp, resolve=True)
|
||||
return
|
||||
for v in hparams.values():
|
||||
if OmegaConf.is_config(v):
|
||||
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
|
||||
OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
|
||||
# deepcopy: hparams from user shouldn't be resolved
|
||||
hparams = deepcopy(hparams)
|
||||
to_container = partial(OmegaConf.to_container, resolve=True)
|
||||
hparams = apply_to_collection(hparams, DictConfig, to_container)
|
||||
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
|
||||
try:
|
||||
OmegaConf.save(hparams, fp)
|
||||
return
|
||||
except (UnsupportedValueType, ValidationError):
|
||||
pass
|
||||
|
||||
if not isinstance(hparams, dict):
|
||||
raise TypeError("hparams must be dictionary")
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
|
||||
|
||||
def _module_available(module_path: str) -> bool:
|
||||
"""Testing if given module is avalaible in your env
|
||||
|
||||
>>> _module_available('os')
|
||||
True
|
||||
>>> _module_available('bla.bla')
|
||||
False
|
||||
"""
|
||||
# todo: find a better way than try / except
|
||||
try:
|
||||
mods = module_path.split('.')
|
||||
assert mods, 'nothing given to test'
|
||||
# it has to be tested as per partets
|
||||
for i in range(len(mods)):
|
||||
module_path = '.'.join(mods[:i + 1])
|
||||
if importlib.util.find_spec(module_path) is None:
|
||||
return False
|
||||
return True
|
||||
except AttributeError:
|
||||
return False
|
|
@ -18,6 +18,8 @@ from argparse import Namespace
|
|||
from typing import Dict, Tuple, Union
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.package_utils import _module_available
|
||||
|
||||
|
||||
def str_to_bool_or_str(val: str) -> Union[str, bool]:
|
||||
|
@ -115,7 +117,6 @@ def get_init_args(frame) -> dict:
|
|||
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
|
||||
filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
|
||||
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
|
||||
|
||||
# only collect variables that appear in the signature
|
||||
local_args = {k: local_vars[k] for k in init_parameters.keys()}
|
||||
local_args.update(local_args.get(kwargs_var, {}))
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
defaults:
|
||||
- training: default
|
||||
|
||||
log: ${training.log}
|
|
@ -0,0 +1,2 @@
|
|||
# @package training
|
||||
log: "Something"
|
|
@ -25,10 +25,13 @@ from torch.nn import functional as F
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
|
||||
from pytorch_lightning.utilities import AttributeDict, is_picklable
|
||||
from pytorch_lightning.utilities import AttributeDict, is_picklable, _HYDRA_EXPERIMENTAL_AVAILABLE
|
||||
from tests.base import BoringModel, EvalModelTemplate, TrialMNIST
|
||||
|
||||
if _HYDRA_EXPERIMENTAL_AVAILABLE:
|
||||
from hydra.experimental import compose, initialize
|
||||
|
||||
class SaveHparamsModel(BoringModel):
|
||||
""" Tests that a model can take an object """
|
||||
|
@ -483,13 +486,13 @@ def test_hparams_save_yaml(tmpdir):
|
|||
path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
|
||||
|
||||
save_hparams_to_yaml(path_yaml, hparams)
|
||||
assert load_hparams_from_yaml(path_yaml) == hparams
|
||||
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
|
||||
|
||||
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
|
||||
assert load_hparams_from_yaml(path_yaml) == hparams
|
||||
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
|
||||
|
||||
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
|
||||
assert load_hparams_from_yaml(path_yaml) == hparams
|
||||
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams
|
||||
|
||||
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
|
||||
assert load_hparams_from_yaml(path_yaml) == hparams
|
||||
|
@ -636,3 +639,46 @@ def test_model_with_fsspec_as_parameter(tmpdir):
|
|||
)
|
||||
trainer.fit(model)
|
||||
trainer.test()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available")
|
||||
def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
|
||||
"""
|
||||
This test relies on configuration saved under tests/models/conf/config.yaml
|
||||
"""
|
||||
|
||||
class TestHydraModel(BoringModel):
|
||||
|
||||
def __init__(self, args_0, args_1, args_2, kwarg_1=None):
|
||||
self.save_hyperparameters()
|
||||
self.test_hparams()
|
||||
config_file = f"{tmpdir}/hparams.yaml"
|
||||
save_hparams_to_yaml(config_file, self.hparams)
|
||||
self.hparams = load_hparams_from_yaml(config_file)
|
||||
self.test_hparams()
|
||||
super().__init__()
|
||||
|
||||
def test_hparams(self):
|
||||
assert self.hparams.args_0.log == "Something"
|
||||
assert self.hparams.args_1['cfg'].log == "Something"
|
||||
assert self.hparams.args_2[0].log == "Something"
|
||||
assert self.hparams.kwarg_1['cfg'][0].log == "Something"
|
||||
|
||||
with initialize(config_path="conf"):
|
||||
args_0 = compose(config_name="config")
|
||||
args_1 = {"cfg": compose(config_name="config")}
|
||||
args_2 = [compose(config_name="config")]
|
||||
kwarg_1 = {"cfg": [compose(config_name="config")]}
|
||||
model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
|
||||
epochs = 2
|
||||
checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[checkpoint_callback],
|
||||
limit_train_batches=10,
|
||||
limit_val_batches=10,
|
||||
max_epochs=epochs,
|
||||
logger=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
|
||||
|
|
Loading…
Reference in New Issue