test: save hparams to yaml (#2198)

* save hparams to yaml

* import

* resolves

* req

* Update requirements/base.txt

Co-authored-by: Omry Yadan <omry@fb.com>

Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
Jirka Borovec 2020-06-16 12:34:55 +02:00 committed by GitHub
parent f94b919b96
commit e289e45120
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 40 additions and 18 deletions

View File

@ -10,14 +10,14 @@ from typing import Union, Dict, Any, Optional, Callable
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.cloud_io import load as pl_load
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
try:
from omegaconf import Container
except ImportError:
pass
Container = None
else:
ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
@ -332,11 +332,25 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
"""
Args:
config_yaml: path to new YAML file
hparams: parameters to be saved
"""
if not os.path.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
if Container is not None and isinstance(hparams, Container):
from omegaconf import OmegaConf
OmegaConf.save(hparams, config_yaml, resolve=True)
return
# saving the standard way
if isinstance(hparams, Namespace):
hparams = vars(hparams)
elif isinstance(hparams, AttributeDict):
hparams = dict(hparams)
assert isinstance(hparams, dict)
with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)

View File

@ -17,11 +17,6 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
try:
from omegaconf import Container
except ImportError:
Container = None
class TensorBoardLogger(LightningLoggerBase):
r"""
@ -156,14 +151,7 @@ class TensorBoardLogger(LightningLoggerBase):
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
# save the metatags file
if Container is not None:
if isinstance(self.hparams, Container):
from omegaconf import OmegaConf
OmegaConf.save(self.hparams, hparams_file, resolve=True)
else:
save_hparams_to_yaml(hparams_file, self.hparams)
else:
save_hparams_to_yaml(hparams_file, self.hparams)
save_hparams_to_yaml(hparams_file, self.hparams)
@rank_zero_only
def finalize(self, status: str) -> None:

View File

@ -101,7 +101,7 @@ from pytorch_lightning.overrides.data_parallel import (
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.cloud_io import load as pl_load
try:
import torch_xla

View File

@ -5,4 +5,5 @@ tqdm>=4.41.0
torch>=1.3
tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
pyyaml>=3.13
# pyyaml>=3.13
PyYAML>=5.1 # OmegaConf requirement

View File

@ -8,6 +8,7 @@ import torch
from omegaconf import OmegaConf, Container
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict
from tests.base import EvalModelTemplate
@ -407,3 +408,21 @@ def test_hparams_pickle(tmpdir):
assert ad == pickle.loads(pkl)
pkl = cloudpickle.dumps(ad)
assert ad == pickle.loads(pkl)
def test_hparams_save_yaml(tmpdir):
hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
nasted=dict(any_num=123, anystr='abcd'))
path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
save_hparams_to_yaml(path_yaml, hparams)
assert load_hparams_from_yaml(path_yaml) == hparams
save_hparams_to_yaml(path_yaml, Namespace(**hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams

View File

@ -19,7 +19,7 @@ from pytorch_lightning.core.saving import (
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.cloud_io import load as pl_load
from tests.base import EvalModelTemplate