From b7c05d279c492a8ceea9bc00ecc6adcb55cc6b7e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 10 Feb 2023 09:22:56 +0100
Subject: [PATCH] Move `Trainer._log_hyperparams` to an utility (#16712)

---
 src/lightning/pytorch/loggers/utilities.py | 42 +++++++++++++++++++++
 src/lightning/pytorch/trainer/trainer.py   | 43 +---------------------
 tests/tests_pytorch/loggers/test_logger.py |  9 ++---
 tests/tests_pytorch/models/test_hparams.py |  5 +--
 4 files changed, 48 insertions(+), 51 deletions(-)

diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py
index 5cdcea10cf..69c9a17c70 100644
--- a/src/lightning/pytorch/loggers/utilities.py
+++ b/src/lightning/pytorch/loggers/utilities.py
@@ -16,6 +16,9 @@ from pathlib import Path
 from typing import Any, List, Tuple, Union
 
+from torch import Tensor
+
+import lightning.pytorch as pl
 from lightning.pytorch.callbacks import Checkpoint
 
 
@@ -52,3 +55,42 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
     )
     checkpoints = [c for c in checkpoints if c[1] not in logged_model_time.keys() or logged_model_time[c[1]] < c[0]]
     return checkpoints
+
+
+def _log_hyperparams(trainer: "pl.Trainer") -> None:
+    if not trainer.loggers:
+        return
+
+    pl_module = trainer.lightning_module
+    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False
+
+    hparams_initial = None
+    if pl_module._log_hyperparams and datamodule_log_hyperparams:
+        datamodule_hparams = trainer.datamodule.hparams_initial
+        lightning_hparams = pl_module.hparams_initial
+        inconsistent_keys = []
+        for key in lightning_hparams.keys() & datamodule_hparams.keys():
+            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+            if type(lm_val) != type(dm_val):
+                inconsistent_keys.append(key)
+            elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
+                inconsistent_keys.append(key)
+            elif lm_val != dm_val:
+                inconsistent_keys.append(key)
+        if inconsistent_keys:
+            raise RuntimeError(
+                f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                "in both the LightningModule's and LightningDataModule's hparams "
+                "but have different values."
+            )
+        hparams_initial = {**lightning_hparams, **datamodule_hparams}
+    elif pl_module._log_hyperparams:
+        hparams_initial = pl_module.hparams_initial
+    elif datamodule_log_hyperparams:
+        hparams_initial = trainer.datamodule.hparams_initial
+
+    for logger in trainer.loggers:
+        if hparams_initial is not None:
+            logger.log_hyperparams(hparams_initial)
+        logger.log_graph(pl_module)
+        logger.save()
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 49fc3bbbcc..e6a7c82943 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -37,7 +37,6 @@ import torch.distributed as dist
 from lightning_utilities.core.apply_func import apply_to_collection
 from lightning_utilities.core.imports import module_available
 from packaging.version import Version
-from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
@@ -54,6 +53,7 @@ from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
 from lightning.pytorch.core.datamodule import LightningDataModule
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
+from lightning.pytorch.loggers.utilities import _log_hyperparams
 from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop
 from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop
 from lightning.pytorch.loops.fit_loop import _FitLoop
@@ -903,7 +903,7 @@ class Trainer:
         self._call_callback_hooks("on_fit_start")
         self._call_lightning_module_hook("on_fit_start")
 
-        self._log_hyperparams()
+        _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:
             log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
@@ -936,45 +936,6 @@ class Trainer:
 
         return results
 
-    def _log_hyperparams(self) -> None:
-        if not self.loggers:
-            return
-        # log hyper-parameters
-        hparams_initial = None
-
-        # save exp to get started (this is where the first experiment logs are written)
-        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
-
-        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
-            datamodule_hparams = self.datamodule.hparams_initial
-            lightning_hparams = self.lightning_module.hparams_initial
-            inconsistent_keys = []
-            for key in lightning_hparams.keys() & datamodule_hparams.keys():
-                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
-                if type(lm_val) != type(dm_val):
-                    inconsistent_keys.append(key)
-                elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
-                    inconsistent_keys.append(key)
-                elif lm_val != dm_val:
-                    inconsistent_keys.append(key)
-            if inconsistent_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams "
-                    "but have different values."
-                )
-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
-        elif self.lightning_module._log_hyperparams:
-            hparams_initial = self.lightning_module.hparams_initial
-        elif datamodule_log_hyperparams:
-            hparams_initial = self.datamodule.hparams_initial
-
-        for logger in self.loggers:
-            if hparams_initial is not None:
-                logger.log_hyperparams(hparams_initial)
-            logger.log_graph(self.lightning_module)
-            logger.save()
-
     def _teardown(self) -> None:
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
         Callback; those are handled by :meth:`_call_teardown_hook`."""
diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py
index fec7f02a5b..aed8fb8e1c 100644
--- a/tests/tests_pytorch/loggers/test_logger.py
+++ b/tests/tests_pytorch/loggers/test_logger.py
@@ -27,7 +27,6 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.loggers import Logger, TensorBoardLogger
 from lightning.pytorch.loggers.logger import DummyExperiment, DummyLogger
 from lightning.pytorch.loggers.utilities import _scan_checkpoints
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 
@@ -253,7 +252,7 @@ def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
 
 
 @patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
-def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
+def test_log_hyperparams_key_collision(_, tmpdir):
     class TestModel(BoringModel):
         def __init__(self, hparams: Dict[str, Any]) -> None:
             super().__init__()
@@ -269,7 +268,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
 
     same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
     model = TestModel(same_params)
-    dm = TestDataModule(same_params)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -289,7 +287,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
     obj_params = deepcopy(same_params)
     obj_params["test"] = _Test()
     model = TestModel(same_params)
-    dm = TestDataModule(obj_params)
     trainer.fit(model)
 
     diff_params = deepcopy(same_params)
@@ -307,7 +304,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+    with pytest.raises(RuntimeError, match="Error while merging hparams"):
         trainer.fit(model, dm)
 
     tensor_params = deepcopy(same_params)
@@ -325,7 +322,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+    with pytest.raises(RuntimeError, match="Error while merging hparams"):
         trainer.fit(model, dm)
 
 
diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py
index 84645386e2..0839975db7 100644
--- a/tests/tests_pytorch/models/test_hparams.py
+++ b/tests/tests_pytorch/models/test_hparams.py
@@ -36,7 +36,6 @@ from lightning.pytorch.core.saving import load_hparams_from_yaml, save_hparams_t
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
-from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
@@ -897,12 +896,10 @@ def test_no_datamodule_for_hparams(tmpdir):
 
 
 def test_colliding_hparams(tmpdir):
-
     model = SaveHparamsModel({"data_dir": "abc", "arg2": "abc"})
     data = DataModuleWithHparams({"data_dir": "foo"})
-
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=CSVLogger(tmpdir))
-    with pytest.raises(MisconfigurationException, match=r"Error while merging hparams:"):
+    with pytest.raises(RuntimeError, match=r"Error while merging hparams:"):
         trainer.fit(model, datamodule=data)
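
The relocated helper keeps the old `Trainer._log_hyperparams` behaviour, but colliding hyperparameter keys now raise `RuntimeError` instead of `MisconfigurationException`, as the updated tests above reflect. A minimal sketch of that behaviour follows; the model/datamodule classes and the "logs" directory are illustrative stand-ins, not part of the patch.

import pytest
from torch.utils.data import DataLoader

from lightning.pytorch import LightningDataModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.loggers import CSVLogger


class HparamsModel(BoringModel):
    def __init__(self, data_dir: str = "abc") -> None:
        super().__init__()
        # recorded in `hparams_initial`, which _log_hyperparams merges at fit time
        self.save_hyperparameters()


class HparamsData(LightningDataModule):
    def __init__(self, data_dir: str = "foo") -> None:
        super().__init__()
        self.save_hyperparameters()

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))


trainer = Trainer(max_epochs=1, logger=CSVLogger("logs"), enable_progress_bar=False)
# "data_dir" is present in both hparams with different values ("abc" vs. "foo"),
# so the merge performed during `fit` raises RuntimeError with this patch applied.
with pytest.raises(RuntimeError, match="Error while merging hparams"):
    trainer.fit(HparamsModel(), datamodule=HparamsData())

Note that the utility compares shared Tensor values by object identity rather than with `==`, since element-wise tensor comparison does not yield a single truth value.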