Add support to Tensorboard logger for OmegaConf hparams (#2846)
* Add support to Tensorboard logger for OmegaConf hparams Address https://github.com/PyTorchLightning/pytorch-lightning/issues/2844 We check if we can import omegaconf, and if the hparams are omegaconf instances. if so, we use OmegaConf.merge to preserve the typing, such that saving hparams to yaml actually triggers the OmegaConf branch * avalaible * chlog * test Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
This commit is contained in:
parent
91b0d46cd5
commit
b39f4798a6
|
@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
|
||||
|
||||
- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846))
|
||||
|
||||
### Changed
|
||||
|
||||
- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
|
||||
|
|
|
@ -17,7 +17,9 @@ ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
|
|||
try:
|
||||
from omegaconf import Container
|
||||
except ImportError:
|
||||
Container = None
|
||||
OMEGACONF_AVAILABLE = False
|
||||
else:
|
||||
OMEGACONF_AVAILABLE = True
|
||||
|
||||
# the older shall be on the top
|
||||
CHECKPOINT_PAST_HPARAMS_KEYS = (
|
||||
|
@ -327,7 +329,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
|
|||
if not os.path.isdir(os.path.dirname(config_yaml)):
|
||||
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
|
||||
|
||||
if Container is not None and isinstance(hparams, Container):
|
||||
if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
|
||||
from omegaconf import OmegaConf
|
||||
OmegaConf.save(hparams, config_yaml, resolve=True)
|
||||
return
|
||||
|
|
|
@ -17,6 +17,13 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
try:
|
||||
from omegaconf import Container, OmegaConf
|
||||
except ImportError:
|
||||
OMEGACONF_AVAILABLE = False
|
||||
else:
|
||||
OMEGACONF_AVAILABLE = True
|
||||
|
||||
|
||||
class TensorBoardLogger(LightningLoggerBase):
|
||||
r"""
|
||||
|
@ -112,7 +119,10 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
params = self._convert_params(params)
|
||||
|
||||
# store params to output
|
||||
self.hparams.update(params)
|
||||
if OMEGACONF_AVAILABLE and isinstance(params, Container):
|
||||
self.hparams = OmegaConf.merge(self.hparams, params)
|
||||
else:
|
||||
self.hparams.update(params)
|
||||
|
||||
# format params into the suitable for tensorboard
|
||||
params = self._flatten_dict(params)
|
||||
|
|
|
@ -132,7 +132,9 @@ else:
|
|||
try:
|
||||
from omegaconf import Container
|
||||
except ImportError:
|
||||
Container = None
|
||||
OMEGACONF_AVAILABLE = False
|
||||
else:
|
||||
OMEGACONF_AVAILABLE = True
|
||||
|
||||
|
||||
class TrainerIOMixin(ABC):
|
||||
|
@ -390,7 +392,7 @@ class TrainerIOMixin(ABC):
|
|||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
|
||||
# add arguments to the checkpoint
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
|
||||
if Container is not None:
|
||||
if OMEGACONF_AVAILABLE:
|
||||
if isinstance(model.hparams, Container):
|
||||
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from argparse import Namespace
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
from packaging import version
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
|
@ -11,29 +12,28 @@ from pytorch_lightning.loggers import TensorBoardLogger
|
|||
from tests.base import EvalModelTemplate
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.5.0'),
|
||||
reason='Minimal PT version is set to 1.5')
|
||||
@pytest.mark.skipif(
|
||||
version.parse(torch.__version__) < version.parse("1.5.0"),
|
||||
reason="Minimal PT version is set to 1.5",
|
||||
)
|
||||
def test_tensorboard_hparams_reload(tmpdir):
|
||||
model = EvalModelTemplate()
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
default_root_dir=tmpdir,
|
||||
)
|
||||
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
|
||||
trainer.fit(model)
|
||||
|
||||
folder_path = trainer.logger.log_dir
|
||||
|
||||
# make sure yaml is there
|
||||
with open(os.path.join(folder_path, 'hparams.yaml')) as file:
|
||||
with open(os.path.join(folder_path, "hparams.yaml")) as file:
|
||||
# The FullLoader parameter handles the conversion from YAML
|
||||
# scalar values to Python the dictionary format
|
||||
yaml_params = yaml.safe_load(file)
|
||||
assert yaml_params['b1'] == 0.5
|
||||
assert yaml_params["b1"] == 0.5
|
||||
assert len(yaml_params.keys()) == 10
|
||||
|
||||
# verify artifacts
|
||||
assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1
|
||||
assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
|
||||
#
|
||||
# # verify tb logs
|
||||
# event_acc = EventAccumulator(folder_path)
|
||||
|
@ -88,13 +88,13 @@ def test_tensorboard_named_version(tmpdir):
|
|||
assert os.listdir(tmpdir / name / expected_version)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("name", ['', None])
|
||||
@pytest.mark.parametrize("name", ["", None])
|
||||
def test_tensorboard_no_name(tmpdir, name):
|
||||
"""Verify that None or empty name works"""
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
|
||||
logger.log_hyperparams({"a": 1, "b": 2}) # Force data to be written
|
||||
assert logger.root_dir == tmpdir
|
||||
assert os.listdir(tmpdir / 'version_0')
|
||||
assert os.listdir(tmpdir / "version_0")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("step_idx", [10, None])
|
||||
|
@ -104,7 +104,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx):
|
|||
"float": 0.3,
|
||||
"int": 1,
|
||||
"FloatTensor": torch.tensor(0.1),
|
||||
"IntTensor": torch.tensor(1)
|
||||
"IntTensor": torch.tensor(1),
|
||||
}
|
||||
logger.log_metrics(metrics, step_idx)
|
||||
|
||||
|
@ -116,10 +116,10 @@ def test_tensorboard_log_hyperparams(tmpdir):
|
|||
"int": 1,
|
||||
"string": "abc",
|
||||
"bool": True,
|
||||
"dict": {'a': {'b': 'c'}},
|
||||
"dict": {"a": {"b": "c"}},
|
||||
"list": [1, 2, 3],
|
||||
"namespace": Namespace(foo=Namespace(bar='buzz')),
|
||||
"layer": torch.nn.BatchNorm1d
|
||||
"namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
"layer": torch.nn.BatchNorm1d,
|
||||
}
|
||||
logger.log_hyperparams(hparams)
|
||||
|
||||
|
@ -131,10 +131,28 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
|
|||
"int": 1,
|
||||
"string": "abc",
|
||||
"bool": True,
|
||||
"dict": {'a': {'b': 'c'}},
|
||||
"dict": {"a": {"b": "c"}},
|
||||
"list": [1, 2, 3],
|
||||
"namespace": Namespace(foo=Namespace(bar='buzz')),
|
||||
"layer": torch.nn.BatchNorm1d
|
||||
"namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
"layer": torch.nn.BatchNorm1d,
|
||||
}
|
||||
metrics = {'abc': torch.tensor([0.54])}
|
||||
metrics = {"abc": torch.tensor([0.54])}
|
||||
logger.log_hyperparams(hparams, metrics)
|
||||
|
||||
|
||||
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
|
||||
logger = TensorBoardLogger(tmpdir)
|
||||
hparams = {
|
||||
"float": 0.3,
|
||||
"int": 1,
|
||||
"string": "abc",
|
||||
"bool": True,
|
||||
"dict": {"a": {"b": "c"}},
|
||||
"list": [1, 2, 3],
|
||||
# "namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
# "layer": torch.nn.BatchNorm1d,
|
||||
}
|
||||
hparams = OmegaConf.create(hparams)
|
||||
|
||||
metrics = {"abc": torch.tensor([0.54])}
|
||||
logger.log_hyperparams(hparams, metrics)
|
||||
|
|
Loading…
Reference in New Issue