Move `Trainer._log_hyperparams` to a utility (#16712)

Carlos Mocholí 2023-02-10 09:22:56 +01:00 committed by GitHub
parent 4b2cf36e77
commit b7c05d279c
4 changed files with 48 additions and 51 deletions

View File

@@ -16,6 +16,9 @@
from pathlib import Path
from typing import Any, List, Tuple, Union
from torch import Tensor
import lightning.pytorch as pl
from lightning.pytorch.callbacks import Checkpoint
@@ -52,3 +55,42 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
    )
    checkpoints = [c for c in checkpoints if c[1] not in logged_model_time.keys() or logged_model_time[c[1]] < c[0]]
    return checkpoints


def _log_hyperparams(trainer: "pl.Trainer") -> None:
    if not trainer.loggers:
        return

    pl_module = trainer.lightning_module
    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False

    hparams_initial = None
    if pl_module._log_hyperparams and datamodule_log_hyperparams:
        datamodule_hparams = trainer.datamodule.hparams_initial
        lightning_hparams = pl_module.hparams_initial
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if type(lm_val) != type(dm_val):
                inconsistent_keys.append(key)
            elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
                inconsistent_keys.append(key)
            elif lm_val != dm_val:
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise RuntimeError(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        hparams_initial = {**lightning_hparams, **datamodule_hparams}
    elif pl_module._log_hyperparams:
        hparams_initial = pl_module.hparams_initial
    elif datamodule_log_hyperparams:
        hparams_initial = trainer.datamodule.hparams_initial

    for logger in trainer.loggers:
        if hparams_initial is not None:
            logger.log_hyperparams(hparams_initial)
        logger.log_graph(pl_module)
        logger.save()
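For illustration only (not part of this commit), here is a minimal standalone sketch of the merge rule that `_log_hyperparams` applies to keys shared by the LightningModule and LightningDataModule hparams: shared keys must hold equal values of the same type, and tensors must be the very same object, otherwise the merge is rejected. `merge_hparams` is a hypothetical helper and the example values are made up.

from torch import Tensor


def merge_hparams(lightning_hparams: dict, datamodule_hparams: dict) -> dict:
    """Merge two hparams dicts, following the consistency rules shown above."""
    inconsistent_keys = []
    for key in lightning_hparams.keys() & datamodule_hparams.keys():
        lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
        if type(lm_val) is not type(dm_val):
            # same key but different types -> reject
            inconsistent_keys.append(key)
        elif isinstance(lm_val, Tensor) and lm_val is not dm_val:
            # tensors must be the very same object (identity, not element-wise equality)
            inconsistent_keys.append(key)
        elif not isinstance(lm_val, Tensor) and lm_val != dm_val:
            # plain values must compare equal
            inconsistent_keys.append(key)
    if inconsistent_keys:
        raise RuntimeError(f"Error while merging hparams: the keys {inconsistent_keys} have different values.")
    return {**lightning_hparams, **datamodule_hparams}


# Shared key "batch_size" holds the same value, so the dicts merge cleanly:
print(merge_hparams({"lr": 0.1, "batch_size": 32}, {"batch_size": 32, "data_dir": "./data"}))
# {'lr': 0.1, 'batch_size': 32, 'data_dir': './data'}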

View File

@@ -37,7 +37,6 @@ import torch.distributed as dist
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.imports import module_available
from packaging.version import Version
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -54,6 +53,7 @@ from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.loggers.utilities import _log_hyperparams
from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop
from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop
from lightning.pytorch.loops.fit_loop import _FitLoop
@@ -903,7 +903,7 @@ class Trainer:
        self._call_callback_hooks("on_fit_start")
        self._call_lightning_module_hook("on_fit_start")

        self._log_hyperparams()
        _log_hyperparams(self)

        if self.strategy.restore_checkpoint_after_setup:
            log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
@@ -936,45 +936,6 @@
        return results

    def _log_hyperparams(self) -> None:
        if not self.loggers:
            return

        # log hyper-parameters
        hparams_initial = None

        # save exp to get started (this is where the first experiment logs are written)
        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
            datamodule_hparams = self.datamodule.hparams_initial
            lightning_hparams = self.lightning_module.hparams_initial
            inconsistent_keys = []
            for key in lightning_hparams.keys() & datamodule_hparams.keys():
                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
                if type(lm_val) != type(dm_val):
                    inconsistent_keys.append(key)
                elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
                    inconsistent_keys.append(key)
                elif lm_val != dm_val:
                    inconsistent_keys.append(key)
            if inconsistent_keys:
                raise MisconfigurationException(
                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
                    "in both the LightningModule's and LightningDataModule's hparams "
                    "but have different values."
                )
            hparams_initial = {**lightning_hparams, **datamodule_hparams}
        elif self.lightning_module._log_hyperparams:
            hparams_initial = self.lightning_module.hparams_initial
        elif datamodule_log_hyperparams:
            hparams_initial = self.datamodule.hparams_initial

        for logger in self.loggers:
            if hparams_initial is not None:
                logger.log_hyperparams(hparams_initial)
            logger.log_graph(self.lightning_module)
            logger.save()

    def _teardown(self) -> None:
        """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
        Callback; those are handled by :meth:`_call_teardown_hook`."""

View File

@@ -27,7 +27,6 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import Logger, TensorBoardLogger
from lightning.pytorch.loggers.logger import DummyExperiment, DummyLogger
from lightning.pytorch.loggers.utilities import _scan_checkpoints
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_only
@@ -253,7 +252,7 @@ def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
@patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
def test_log_hyperparams_key_collision(_, tmpdir):
    class TestModel(BoringModel):
        def __init__(self, hparams: Dict[str, Any]) -> None:
            super().__init__()
@@ -269,7 +268,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
    same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
    model = TestModel(same_params)
    dm = TestDataModule(same_params)
    trainer = Trainer(
        default_root_dir=tmpdir,
@@ -289,7 +287,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
    obj_params = deepcopy(same_params)
    obj_params["test"] = _Test()
    model = TestModel(same_params)
    dm = TestDataModule(obj_params)
    trainer.fit(model)
    diff_params = deepcopy(same_params)
@@ -307,7 +304,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
    with pytest.raises(RuntimeError, match="Error while merging hparams"):
        trainer.fit(model, dm)
    tensor_params = deepcopy(same_params)
@@ -325,7 +322,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
    with pytest.raises(RuntimeError, match="Error while merging hparams"):
        trainer.fit(model, dm)

View File

@@ -36,7 +36,6 @@ from lightning.pytorch.core.saving import load_hparams_from_yaml, save_hparams_t
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf
if _OMEGACONF_AVAILABLE:
@@ -897,12 +896,10 @@ def test_no_datamodule_for_hparams(tmpdir):
def test_colliding_hparams(tmpdir):
    model = SaveHparamsModel({"data_dir": "abc", "arg2": "abc"})
    data = DataModuleWithHparams({"data_dir": "foo"})
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=CSVLogger(tmpdir))
    with pytest.raises(MisconfigurationException, match=r"Error while merging hparams:"):
    with pytest.raises(RuntimeError, match=r"Error while merging hparams:"):
        trainer.fit(model, datamodule=data)
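As a usage note, code that previously caught `MisconfigurationException` on colliding hyperparameters should now expect `RuntimeError`. Below is a rough sketch of the new behavior, assuming the `BoringModel`/`BoringDataModule` demo classes and `save_hyperparameters` behave as in the tests above; the subclass names are hypothetical and used only for illustration.

import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import CSVLogger


class ModelWithHparams(BoringModel):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters({"data_dir": "abc"})


class DataWithHparams(BoringDataModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters({"data_dir": "foo"})  # collides with the model's value


trainer = Trainer(max_epochs=1, logger=CSVLogger("logs"), enable_progress_bar=False, enable_model_summary=False)
# The collision is detected while logging hyperparameters at the start of fit:
with pytest.raises(RuntimeError, match="Error while merging hparams"):
    trainer.fit(ModelWithHparams(), datamodule=DataWithHparams())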