From b39f4798a6859d2237b48b29b39a2390164612c1 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 7 Aug 2020 06:13:21 -0700 Subject: [PATCH] Add support to Tensorboard logger for OmegaConf hparams (#2846) * Add support to Tensorboard logger for OmegaConf hparams Address https://github.com/PyTorchLightning/pytorch-lightning/issues/2844 We check if we can import omegaconf, and if the hparams are omegaconf instances. if so, we use OmegaConf.merge to preserve the typing, such that saving hparams to yaml actually triggers the OmegaConf branch * avalaible * chlog * test Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 + pytorch_lightning/core/saving.py | 6 ++- pytorch_lightning/loggers/tensorboard.py | 12 ++++- pytorch_lightning/trainer/training_io.py | 6 ++- tests/loggers/test_tensorboard.py | 56 ++++++++++++++++-------- 5 files changed, 58 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f29c17ddfe..7e9788334c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935)) +- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846)) + ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 37c63de180..20fa8a8840 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -17,7 +17,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) try: from omegaconf import Container except ImportError: - Container = None + OMEGACONF_AVAILABLE = False +else: + OMEGACONF_AVAILABLE = True # the older shall be on the top CHECKPOINT_PAST_HPARAMS_KEYS = ( @@ -327,7 +329,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: if not os.path.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.') - if Container is not None and isinstance(hparams, Container): + if OMEGACONF_AVAILABLE and isinstance(hparams, Container): from omegaconf import OmegaConf OmegaConf.save(hparams, config_yaml, resolve=True) return diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 23394b93f7..f88b5f97cf 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -17,6 +17,13 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only +try: + from omegaconf import Container, OmegaConf +except ImportError: + OMEGACONF_AVAILABLE = False +else: + OMEGACONF_AVAILABLE = True + class TensorBoardLogger(LightningLoggerBase): r""" @@ -112,7 +119,10 @@ class TensorBoardLogger(LightningLoggerBase): params = self._convert_params(params) # store params to output - self.hparams.update(params) + if OMEGACONF_AVAILABLE and isinstance(params, Container): + self.hparams = OmegaConf.merge(self.hparams, params) + else: + self.hparams.update(params) # format params into the suitable for tensorboard params = self._flatten_dict(params) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 666dfbb258..1c22390355 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -132,7 +132,9 @@ else: try: from omegaconf import Container except ImportError: - Container = None + OMEGACONF_AVAILABLE = False +else: + OMEGACONF_AVAILABLE = True class TrainerIOMixin(ABC): @@ -390,7 +392,7 @@ class TrainerIOMixin(ABC): checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name # add arguments to the checkpoint checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - if Container is not None: + if OMEGACONF_AVAILABLE: if isinstance(model.hparams, Container): checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index e5aec716a2..4ae1ba8484 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -4,6 +4,7 @@ from argparse import Namespace import pytest import torch import yaml +from omegaconf import OmegaConf from packaging import version from pytorch_lightning import Trainer @@ -11,29 +12,28 @@ from pytorch_lightning.loggers import TensorBoardLogger from tests.base import EvalModelTemplate -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.5.0'), - reason='Minimal PT version is set to 1.5') +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("1.5.0"), + reason="Minimal PT version is set to 1.5", +) def test_tensorboard_hparams_reload(tmpdir): model = EvalModelTemplate() - trainer = Trainer( - max_epochs=1, - default_root_dir=tmpdir, - ) + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) trainer.fit(model) folder_path = trainer.logger.log_dir # make sure yaml is there - with open(os.path.join(folder_path, 'hparams.yaml')) as file: + with open(os.path.join(folder_path, "hparams.yaml")) as file: # The FullLoader parameter handles the conversion from YAML # scalar values to Python the dictionary format yaml_params = yaml.safe_load(file) - assert yaml_params['b1'] == 0.5 + assert yaml_params["b1"] == 0.5 assert len(yaml_params.keys()) == 10 # verify artifacts - assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1 + assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1 # # # verify tb logs # event_acc = EventAccumulator(folder_path) @@ -88,13 +88,13 @@ def test_tensorboard_named_version(tmpdir): assert os.listdir(tmpdir / name / expected_version) -@pytest.mark.parametrize("name", ['', None]) +@pytest.mark.parametrize("name", ["", None]) def test_tensorboard_no_name(tmpdir, name): """Verify that None or empty name works""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written assert logger.root_dir == tmpdir - assert os.listdir(tmpdir / 'version_0') + assert os.listdir(tmpdir / "version_0") @pytest.mark.parametrize("step_idx", [10, None]) @@ -104,7 +104,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx): "float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), - "IntTensor": torch.tensor(1) + "IntTensor": torch.tensor(1), } logger.log_metrics(metrics, step_idx) @@ -116,10 +116,10 @@ def test_tensorboard_log_hyperparams(tmpdir): "int": 1, "string": "abc", "bool": True, - "dict": {'a': {'b': 'c'}}, + "dict": {"a": {"b": "c"}}, "list": [1, 2, 3], - "namespace": Namespace(foo=Namespace(bar='buzz')), - "layer": torch.nn.BatchNorm1d + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, } logger.log_hyperparams(hparams) @@ -131,10 +131,28 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir): "int": 1, "string": "abc", "bool": True, - "dict": {'a': {'b': 'c'}}, + "dict": {"a": {"b": "c"}}, "list": [1, 2, 3], - "namespace": Namespace(foo=Namespace(bar='buzz')), - "layer": torch.nn.BatchNorm1d + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, } - metrics = {'abc': torch.tensor([0.54])} + metrics = {"abc": torch.tensor([0.54])} + logger.log_hyperparams(hparams, metrics) + + +def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir): + logger = TensorBoardLogger(tmpdir) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": {"a": {"b": "c"}}, + "list": [1, 2, 3], + # "namespace": Namespace(foo=Namespace(bar="buzz")), + # "layer": torch.nn.BatchNorm1d, + } + hparams = OmegaConf.create(hparams) + + metrics = {"abc": torch.tensor([0.54])} logger.log_hyperparams(hparams, metrics)