Add support to Tensorboard logger for OmegaConf hparams (#2846)

* Add support to Tensorboard logger for OmegaConf hparams

Address https://github.com/PyTorchLightning/pytorch-lightning/issues/2844

We check whether omegaconf can be imported and whether the hparams are OmegaConf instances. If so, we use OmegaConf.merge to preserve the typing, so that saving hparams to YAML actually takes the OmegaConf branch (a usage sketch follows the commit message below).

* available

* chlog

* test

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
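
A usage sketch of what this change enables, assuming omegaconf is installed (the config keys and save_dir are illustrative, not part of the patch): an OmegaConf config can now be passed straight to the TensorBoard logger, is merged into its stored hparams via OmegaConf.merge, and is written back out as OmegaConf YAML by save_hparams_to_yaml.

    from omegaconf import OmegaConf
    from pytorch_lightning.loggers import TensorBoardLogger

    conf = OmegaConf.create({"lr": 0.01, "model": {"hidden_dim": 128}})
    logger = TensorBoardLogger(save_dir="logs")
    logger.log_hyperparams(conf)  # merged via OmegaConf.merge instead of dict.update
    logger.save()                 # hparams.yaml is written through the OmegaConf branch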
ananthsub 2020-08-07 06:13:21 -07:00 committed by GitHub
parent 91b0d46cd5
commit b39f4798a6
5 changed files with 58 additions and 24 deletions

CHANGELOG.md

@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
+- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846))
 ### Changed
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

pytorch_lightning/core/saving.py

@@ -17,7 +17,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = (
@@ -327,7 +329,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     if not os.path.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
-    if Container is not None and isinstance(hparams, Container):
+    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return
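
For context on the branch above: OmegaConf.save(..., resolve=True) resolves interpolations before writing, and OmegaConf.load returns a DictConfig rather than a plain dict, which is the "preserve the typing" behaviour the commit message refers to. A minimal round-trip sketch (the file path and keys are illustrative):

    from omegaconf import OmegaConf

    conf = OmegaConf.create({"lr": 0.01, "scheduler": {"gamma": 0.5}})
    OmegaConf.save(conf, "hparams.yaml", resolve=True)  # resolved YAML on disk
    reloaded = OmegaConf.load("hparams.yaml")           # comes back as a DictConfig
    assert reloaded.scheduler.gamma == 0.5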

pytorch_lightning/loggers/tensorboard.py

@@ -17,6 +17,13 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
+try:
+    from omegaconf import Container, OmegaConf
+except ImportError:
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -112,7 +119,10 @@ class TensorBoardLogger(LightningLoggerBase):
         params = self._convert_params(params)
         # store params to output
-        self.hparams.update(params)
+        if OMEGACONF_AVAILABLE and isinstance(params, Container):
+            self.hparams = OmegaConf.merge(self.hparams, params)
+        else:
+            self.hparams.update(params)
         # format params into the suitable for tensorboard
         params = self._flatten_dict(params)
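
The branch above is the heart of the change: dict.update would copy an OmegaConf config's values into the logger's plain hparams dict and lose the Container type, while OmegaConf.merge returns a DictConfig that keeps it. A rough illustration of the difference, assuming omegaconf is installed (variable names are illustrative):

    from omegaconf import OmegaConf

    stored = {}                               # what the logger's hparams start as
    params = OmegaConf.create({"lr": 0.01})

    merged = OmegaConf.merge(stored, params)  # DictConfig -- Container type preserved
    stored.update(params)                     # plain dict -- OmegaConf typing lost

    print(type(merged).__name__, type(stored).__name__)  # DictConfig dict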

pytorch_lightning/trainer/training_io.py

@@ -132,7 +132,9 @@ else:
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 class TrainerIOMixin(ABC):
@@ -390,7 +392,7 @@ class TrainerIOMixin(ABC):
         checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
         # add arguments to the checkpoint
         checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-        if Container is not None:
+        if OMEGACONF_AVAILABLE:
             if isinstance(model.hparams, Container):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)

tests/loggers/test_tensorboard.py

@@ -4,6 +4,7 @@ from argparse import Namespace
 import pytest
 import torch
 import yaml
+from omegaconf import OmegaConf
 from packaging import version
 from pytorch_lightning import Trainer
@@ -11,29 +12,28 @@ from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.5.0'),
-                    reason='Minimal PT version is set to 1.5')
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse("1.5.0"),
+    reason="Minimal PT version is set to 1.5",
+)
 def test_tensorboard_hparams_reload(tmpdir):
     model = EvalModelTemplate()
-    trainer = Trainer(
-        max_epochs=1,
-        default_root_dir=tmpdir,
-    )
+    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
     trainer.fit(model)
     folder_path = trainer.logger.log_dir
     # make sure yaml is there
-    with open(os.path.join(folder_path, 'hparams.yaml')) as file:
+    with open(os.path.join(folder_path, "hparams.yaml")) as file:
         # The FullLoader parameter handles the conversion from YAML
         # scalar values to Python the dictionary format
         yaml_params = yaml.safe_load(file)
-        assert yaml_params['b1'] == 0.5
+        assert yaml_params["b1"] == 0.5
         assert len(yaml_params.keys()) == 10
     # verify artifacts
-    assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1
+    assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
     #
     # # verify tb logs
     # event_acc = EventAccumulator(folder_path)
@@ -88,13 +88,13 @@ def test_tensorboard_named_version(tmpdir):
     assert os.listdir(tmpdir / name / expected_version)
-@pytest.mark.parametrize("name", ['', None])
+@pytest.mark.parametrize("name", ["", None])
 def test_tensorboard_no_name(tmpdir, name):
     """Verify that None or empty name works"""
     logger = TensorBoardLogger(save_dir=tmpdir, name=name)
     logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written
     assert logger.root_dir == tmpdir
-    assert os.listdir(tmpdir / 'version_0')
+    assert os.listdir(tmpdir / "version_0")
 @pytest.mark.parametrize("step_idx", [10, None])
@@ -104,7 +104,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx):
         "float": 0.3,
         "int": 1,
         "FloatTensor": torch.tensor(0.1),
-        "IntTensor": torch.tensor(1)
+        "IntTensor": torch.tensor(1),
     }
     logger.log_metrics(metrics, step_idx)
@@ -116,10 +116,10 @@ def test_tensorboard_log_hyperparams(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
     logger.log_hyperparams(hparams)
@@ -131,10 +131,28 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
-    metrics = {'abc': torch.tensor([0.54])}
+    metrics = {"abc": torch.tensor([0.54])}
     logger.log_hyperparams(hparams, metrics)
+
+
+def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
+    logger = TensorBoardLogger(tmpdir)
+
+    hparams = {
+        "float": 0.3,
+        "int": 1,
+        "string": "abc",
+        "bool": True,
+        "dict": {"a": {"b": "c"}},
+        "list": [1, 2, 3],
+        # "namespace": Namespace(foo=Namespace(bar="buzz")),
+        # "layer": torch.nn.BatchNorm1d,
+    }
+    hparams = OmegaConf.create(hparams)
+
+    metrics = {"abc": torch.tensor([0.54])}
+    logger.log_hyperparams(hparams, metrics)