test: save hparams to yaml (#2198)
* save hparams to yaml
* import
* resolves
* req
* Update requirements/base.txt

Co-authored-by: Omry Yadan <omry@fb.com>
This commit is contained in:
parent
f94b919b96
commit
e289e45120
@@ -10,14 +10,14 @@ from typing import Union, Dict, Any, Optional, Callable
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load

 PRIMITIVE_TYPES = (bool, int, float, str)
 ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
 try:
     from omegaconf import Container
 except ImportError:
-    pass
+    Container = None
 else:
     ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
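Two things change in the hunk above: the pl_load helper now comes from pytorch_lightning.utilities.cloud_io, and a failed omegaconf import sets Container = None so later isinstance checks can be guarded. A minimal sketch of calling the helper through the new path, assuming it keeps a torch.load-style map_location argument; the checkpoint path is a placeholder:

from pytorch_lightning.utilities.cloud_io import load as pl_load

# load a checkpoint dict; cloud_io.load is meant to accept local paths as well as URLs
checkpoint = pl_load('/path/to/checkpoint.ckpt', map_location='cpu')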
@@ -332,11 +332,25 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    """
    Args:
        config_yaml: path to new YAML file
        hparams: parameters to be saved
    """
    if not os.path.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

    if Container is not None and isinstance(hparams, Container):
        from omegaconf import OmegaConf
        OmegaConf.save(hparams, config_yaml, resolve=True)
        return

    # saving the standard way
    if isinstance(hparams, Namespace):
        hparams = vars(hparams)
    elif isinstance(hparams, AttributeDict):
        hparams = dict(hparams)
    assert isinstance(hparams, dict)

    with open(config_yaml, 'w', newline='') as fp:
        yaml.dump(hparams, fp)
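A minimal round-trip sketch of the function above paired with the existing load_hparams_from_yaml; the hyperparameter values and file name are invented, and it assumes a pytorch_lightning build that already contains this change:

import os
import tempfile
from argparse import Namespace

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

out_dir = tempfile.mkdtemp()                       # the target folder must exist, or a RuntimeError is raised
path_yaml = os.path.join(out_dir, 'hparams.yaml')

hparams = Namespace(batch_size=32, learning_rate=0.001)
save_hparams_to_yaml(path_yaml, hparams)           # a Namespace is converted to a plain dict before dumping
assert load_hparams_from_yaml(path_yaml) == vars(hparams)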
@@ -17,11 +17,6 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only

-try:
-    from omegaconf import Container
-except ImportError:
-    Container = None
-

 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -156,14 +151,7 @@
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

         # save the metatags file
-        if Container is not None:
-            if isinstance(self.hparams, Container):
-                from omegaconf import OmegaConf
-                OmegaConf.save(self.hparams, hparams_file, resolve=True)
-            else:
-                save_hparams_to_yaml(hparams_file, self.hparams)
-        else:
-            save_hparams_to_yaml(hparams_file, self.hparams)
+        save_hparams_to_yaml(hparams_file, self.hparams)

     @rank_zero_only
     def finalize(self, status: str) -> None:
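For context on the OmegaConf branch that save_hparams_to_yaml now owns (and that the logger code above used to duplicate), a small sketch of what resolve=True does when saving a Container; the config keys and values are invented:

import os
import tempfile

from omegaconf import OmegaConf

conf = OmegaConf.create({'data_root': '/data', 'train_dir': '${data_root}/train'})

out = os.path.join(tempfile.mkdtemp(), 'hparams.yaml')
# resolve=True expands the ${data_root} interpolation, so the YAML on disk
# stores the literal '/data/train' instead of the interpolation string
OmegaConf.save(conf, out, resolve=True)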
@@ -101,7 +101,7 @@ from pytorch_lightning.overrides.data_parallel import (
     LightningDataParallel,
 )
 from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load

 try:
     import torch_xla
@@ -5,4 +5,5 @@ tqdm>=4.41.0
 torch>=1.3
 tensorboard>=1.14
 future>=0.17.1  # required for builtins in setup.py
-pyyaml>=3.13
+# pyyaml>=3.13
+PyYAML>=5.1  # OmegaConf requirement
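A quick, assumption-laden sanity check that the installed PyYAML meets the new >=5.1 pin; the version parsing below is purely illustrative:

import yaml

# PyYAML exposes its version as a plain string such as '5.3.1'
major, minor = (int(p) for p in yaml.__version__.split('.')[:2])
assert (major, minor) >= (5, 1), f'PyYAML {yaml.__version__} is older than the 5.1 pin'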
@@ -8,6 +8,7 @@ import torch
 from omegaconf import OmegaConf, Container

 from pytorch_lightning import Trainer, LightningModule
+from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
 from pytorch_lightning.utilities import AttributeDict
 from tests.base import EvalModelTemplate
@@ -407,3 +408,21 @@ def test_hparams_pickle(tmpdir):
     assert ad == pickle.loads(pkl)
     pkl = cloudpickle.dumps(ad)
     assert ad == pickle.loads(pkl)
+
+
+def test_hparams_save_yaml(tmpdir):
+    hparams = dict(batch_size=32, learning_rate=0.001, data_root='./any/path/here',
+                   nasted=dict(any_num=123, anystr='abcd'))
+    path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')
+
+    save_hparams_to_yaml(path_yaml, hparams)
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, Namespace(**hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
+
+    save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
+    assert load_hparams_from_yaml(path_yaml) == hparams
@@ -19,7 +19,7 @@ from pytorch_lightning.core.saving import (
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.io import load as pl_load
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from tests.base import EvalModelTemplate