diff --git a/CHANGELOG.md b/CHANGELOG.md
index f29c17ddfe..7e9788334c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))

+- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846))
+
 ### Changed

 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 37c63de180..20fa8a8840 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -17,7 +17,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True

 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = (
@@ -327,7 +329,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     if not os.path.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

-    if Container is not None and isinstance(hparams, Container):
+    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 23394b93f7..f88b5f97cf 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -17,6 +17,13 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only

+try:
+    from omegaconf import Container, OmegaConf
+except ImportError:
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
+

 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -112,7 +119,10 @@ class TensorBoardLogger(LightningLoggerBase):
         params = self._convert_params(params)

         # store params to output
-        self.hparams.update(params)
+        if OMEGACONF_AVAILABLE and isinstance(params, Container):
+            self.hparams = OmegaConf.merge(self.hparams, params)
+        else:
+            self.hparams.update(params)

         # format params into the suitable for tensorboard
         params = self._flatten_dict(params)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 666dfbb258..1c22390355 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -132,7 +132,9 @@ else:
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True


 class TrainerIOMixin(ABC):
@@ -390,7 +392,7 @@ class TrainerIOMixin(ABC):
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # add arguments to the checkpoint
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-            if Container is not None:
+            if OMEGACONF_AVAILABLE:
                 if isinstance(model.hparams, Container):
                     checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)

diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index e5aec716a2..4ae1ba8484 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -4,6 +4,7 @@ from argparse import Namespace
 import pytest
 import torch
 import yaml
+from omegaconf import OmegaConf
 from packaging import version

 from pytorch_lightning import Trainer
@@ -11,29 +12,28 @@ from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate


-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.5.0'),
-                    reason='Minimal PT version is set to 1.5')
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse("1.5.0"),
+    reason="Minimal PT version is set to 1.5",
+)
 def test_tensorboard_hparams_reload(tmpdir):
     model = EvalModelTemplate()

-    trainer = Trainer(
-        max_epochs=1,
-        default_root_dir=tmpdir,
-    )
+    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
     trainer.fit(model)

     folder_path = trainer.logger.log_dir

     # make sure yaml is there
-    with open(os.path.join(folder_path, 'hparams.yaml')) as file:
+    with open(os.path.join(folder_path, "hparams.yaml")) as file:
         # The FullLoader parameter handles the conversion from YAML
         # scalar values to Python the dictionary format
         yaml_params = yaml.safe_load(file)
-        assert yaml_params['b1'] == 0.5
+        assert yaml_params["b1"] == 0.5
         assert len(yaml_params.keys()) == 10

     # verify artifacts
-    assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1
+    assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1

     # # # verify tb logs
     # event_acc = EventAccumulator(folder_path)
@@ -88,13 +88,13 @@ def test_tensorboard_named_version(tmpdir):
     assert os.listdir(tmpdir / name / expected_version)


-@pytest.mark.parametrize("name", ['', None])
+@pytest.mark.parametrize("name", ["", None])
 def test_tensorboard_no_name(tmpdir, name):
     """Verify that None or empty name works"""
     logger = TensorBoardLogger(save_dir=tmpdir, name=name)
     logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written
     assert logger.root_dir == tmpdir
-    assert os.listdir(tmpdir / 'version_0')
+    assert os.listdir(tmpdir / "version_0")


 @pytest.mark.parametrize("step_idx", [10, None])
@@ -104,7 +104,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx):
         "float": 0.3,
         "int": 1,
         "FloatTensor": torch.tensor(0.1),
-        "IntTensor": torch.tensor(1)
+        "IntTensor": torch.tensor(1),
     }
     logger.log_metrics(metrics, step_idx)
@@ -116,10 +116,10 @@ def test_tensorboard_log_hyperparams(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
     logger.log_hyperparams(hparams)
@@ -131,10 +131,28 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
-    metrics = {'abc': torch.tensor([0.54])}
+    metrics = {"abc": torch.tensor([0.54])}
+    logger.log_hyperparams(hparams, metrics)
+
+
+def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
+    logger = TensorBoardLogger(tmpdir)
+    hparams = {
+        "float": 0.3,
+        "int": 1,
+        "string": "abc",
+        "bool": True,
+        "dict": {"a": {"b": "c"}},
+        "list": [1, 2, 3],
+        # "namespace": Namespace(foo=Namespace(bar="buzz")),
+        # "layer": torch.nn.BatchNorm1d,
+    }
+    hparams = OmegaConf.create(hparams)
+
+    metrics = {"abc": torch.tensor([0.54])}
     logger.log_hyperparams(hparams, metrics)
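
Note (not part of the patch): the recurring try/except/else block above replaces the old `Container = None` sentinel with a module-level boolean feature flag. A minimal standalone sketch of the idiom, with `save_config` as an illustrative name:

try:
    from omegaconf import Container  # optional dependency
except ImportError:
    OMEGACONF_AVAILABLE = False
else:
    # 'else' runs only if the import succeeded, so the flag and the
    # imported name stay consistent with each other.
    OMEGACONF_AVAILABLE = True

def save_config(obj):
    # Check the flag before the isinstance test: referencing Container
    # after a failed import would raise NameError.
    if OMEGACONF_AVAILABLE and isinstance(obj, Container):
        ...  # OmegaConf-specific path

The flag also reads more directly at call sites than `Container is not None` did.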
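An end-to-end usage sketch of the feature itself, also outside the patch: it assumes omegaconf is installed and the TensorBoardLogger API of this era (`log_hyperparams` taking an optional metrics dict, `save()` writing hparams.yaml via `save_hparams_to_yaml`); the config keys are illustrative.

import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import TensorBoardLogger

# OmegaConf.create returns a DictConfig, which is an omegaconf
# Container, so log_hyperparams takes the OmegaConf.merge path added
# by this patch instead of plain dict.update.
conf = OmegaConf.create({
    "lr": 0.02,
    "model": {"hidden": 128},
    "warmup_lr": "${lr}",  # interpolation, resolved when saved
})

logger = TensorBoardLogger(save_dir="logs")
logger.log_hyperparams(conf, metrics={"hp_metric": torch.tensor([0.54])})
# save() routes Container hparams through OmegaConf.save(..., resolve=True),
# so hparams.yaml contains warmup_lr: 0.02 rather than the raw "${lr}".
logger.save()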