Move `Trainer._log_hyperparams` to a utility (#16712)
parent 4b2cf36e77
commit b7c05d279c
```diff
@@ -16,6 +16,9 @@
 from pathlib import Path
 from typing import Any, List, Tuple, Union
 
+from torch import Tensor
+
+import lightning.pytorch as pl
 from lightning.pytorch.callbacks import Checkpoint
 
 
@@ -52,3 +55,42 @@ def _scan_checkpoints(checkpoint_callback: Checkpoint, logged_model_time: dict)
     )
     checkpoints = [c for c in checkpoints if c[1] not in logged_model_time.keys() or logged_model_time[c[1]] < c[0]]
     return checkpoints
+
+
+def _log_hyperparams(trainer: "pl.Trainer") -> None:
+    if not trainer.loggers:
+        return
+
+    pl_module = trainer.lightning_module
+    datamodule_log_hyperparams = trainer.datamodule._log_hyperparams if trainer.datamodule is not None else False
+
+    hparams_initial = None
+    if pl_module._log_hyperparams and datamodule_log_hyperparams:
+        datamodule_hparams = trainer.datamodule.hparams_initial
+        lightning_hparams = pl_module.hparams_initial
+        inconsistent_keys = []
+        for key in lightning_hparams.keys() & datamodule_hparams.keys():
+            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+            if type(lm_val) != type(dm_val):
+                inconsistent_keys.append(key)
+            elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
+                inconsistent_keys.append(key)
+            elif lm_val != dm_val:
+                inconsistent_keys.append(key)
+        if inconsistent_keys:
+            raise RuntimeError(
+                f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                "in both the LightningModule's and LightningDataModule's hparams "
+                "but have different values."
+            )
+        hparams_initial = {**lightning_hparams, **datamodule_hparams}
+    elif pl_module._log_hyperparams:
+        hparams_initial = pl_module.hparams_initial
+    elif datamodule_log_hyperparams:
+        hparams_initial = trainer.datamodule.hparams_initial
+
+    for logger in trainer.loggers:
+        if hparams_initial is not None:
+            logger.log_hyperparams(hparams_initial)
+        logger.log_graph(pl_module)
+        logger.save()
```
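The merge rules in the hunk above are easiest to see in isolation. The sketch below is a standalone illustration on plain dicts (the `merge_hparams` helper is my own name, not library code): tensors are compared by object identity, everything else by type and then value, and any mismatch on a shared key is reported as a collision.

```python
import torch
from torch import Tensor


def merge_hparams(lightning_hparams: dict, datamodule_hparams: dict) -> dict:
    """Toy version of the collision check performed by the new utility above."""
    inconsistent_keys = []
    for key in lightning_hparams.keys() & datamodule_hparams.keys():
        lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
        if type(lm_val) != type(dm_val):
            inconsistent_keys.append(key)  # e.g. int on one side, str on the other
        elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
            inconsistent_keys.append(key)  # tensors must be the *same object*
        elif lm_val != dm_val:
            inconsistent_keys.append(key)  # plain values must compare equal
    if inconsistent_keys:
        raise RuntimeError(f"Error while merging hparams: the keys {inconsistent_keys} collide.")
    return {**lightning_hparams, **datamodule_hparams}


t = torch.tensor(4)
print(merge_hparams({"lr": 0.1, "4": t}, {"batch_size": 32, "4": t}))  # shared tensor object: merges fine

try:
    merge_hparams({"lr": 0.1}, {"lr": 0.2})
except RuntimeError as err:
    print(err)  # "lr" is present on both sides with different values
```

Note that after the move the failure is a plain `RuntimeError` rather than `MisconfigurationException`, which is why the test updates further down change their `pytest.raises` targets.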
```diff
@@ -37,7 +37,6 @@ import torch.distributed as dist
 from lightning_utilities.core.apply_func import apply_to_collection
 from lightning_utilities.core.imports import module_available
 from packaging.version import Version
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
@@ -54,6 +53,7 @@ from lightning.pytorch.callbacks.prediction_writer import BasePredictionWriter
 from lightning.pytorch.core.datamodule import LightningDataModule
 from lightning.pytorch.loggers import Logger
 from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
+from lightning.pytorch.loggers.utilities import _log_hyperparams
 from lightning.pytorch.loops import _PredictionLoop, _TrainingEpochLoop
 from lightning.pytorch.loops.dataloader.evaluation_loop import _EvaluationLoop
 from lightning.pytorch.loops.fit_loop import _FitLoop
@@ -903,7 +903,7 @@ class Trainer:
         self._call_callback_hooks("on_fit_start")
         self._call_lightning_module_hook("on_fit_start")
 
-        self._log_hyperparams()
+        _log_hyperparams(self)
 
         if self.strategy.restore_checkpoint_after_setup:
             log.detail(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
@@ -936,45 +936,6 @@ class Trainer:
 
         return results
 
-    def _log_hyperparams(self) -> None:
-        if not self.loggers:
-            return
-        # log hyper-parameters
-        hparams_initial = None
-
-        # save exp to get started (this is where the first experiment logs are written)
-        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
-
-        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
-            datamodule_hparams = self.datamodule.hparams_initial
-            lightning_hparams = self.lightning_module.hparams_initial
-            inconsistent_keys = []
-            for key in lightning_hparams.keys() & datamodule_hparams.keys():
-                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
-                if type(lm_val) != type(dm_val):
-                    inconsistent_keys.append(key)
-                elif isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val):
-                    inconsistent_keys.append(key)
-                elif lm_val != dm_val:
-                    inconsistent_keys.append(key)
-            if inconsistent_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams "
-                    "but have different values."
-                )
-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
-        elif self.lightning_module._log_hyperparams:
-            hparams_initial = self.lightning_module.hparams_initial
-        elif datamodule_log_hyperparams:
-            hparams_initial = self.datamodule.hparams_initial
-
-        for logger in self.loggers:
-            if hparams_initial is not None:
-                logger.log_hyperparams(hparams_initial)
-            logger.log_graph(self.lightning_module)
-            logger.save()
-
     def _teardown(self) -> None:
         """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and
         Callback; those are handled by :meth:`_call_teardown_hook`."""
```
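From the user's side nothing changes: `fit()` still logs hyperparameters right after the `on_fit_start` hooks, it just routes through the module-level utility now. A minimal sketch of the unchanged behaviour (my own example; `HparamsModel` and the `"logs"` directory are illustrative, assuming a Lightning 2.x install with the bundled demo classes):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import CSVLogger


class HparamsModel(BoringModel):
    def __init__(self, lr: float = 0.1) -> None:
        super().__init__()
        self.save_hyperparameters()  # fills hparams_initial, which the utility picks up


trainer = Trainer(
    default_root_dir="logs",
    logger=CSVLogger("logs"),
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=0,
    enable_progress_bar=False,
    enable_model_summary=False,
)
trainer.fit(HparamsModel())  # "lr" is handed to CSVLogger.log_hyperparams during fit
```

Pulling the logic out of `Trainer` also means it can be exercised or reused against any object exposing `loggers`, `lightning_module`, and `datamodule`, rather than only from inside the trainer's fit path.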
```diff
@@ -27,7 +27,6 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.loggers import Logger, TensorBoardLogger
 from lightning.pytorch.loggers.logger import DummyExperiment, DummyLogger
 from lightning.pytorch.loggers.utilities import _scan_checkpoints
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 
@@ -253,7 +252,7 @@ def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
 
 
 @patch("lightning.pytorch.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
-def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
+def test_log_hyperparams_key_collision(_, tmpdir):
     class TestModel(BoringModel):
         def __init__(self, hparams: Dict[str, Any]) -> None:
             super().__init__()
@@ -269,7 +268,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
 
     same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
     model = TestModel(same_params)
     dm = TestDataModule(same_params)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -289,7 +287,6 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
     obj_params = deepcopy(same_params)
     obj_params["test"] = _Test()
     model = TestModel(same_params)
     dm = TestDataModule(obj_params)
     trainer.fit(model)
 
     diff_params = deepcopy(same_params)
@@ -307,7 +304,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+    with pytest.raises(RuntimeError, match="Error while merging hparams"):
         trainer.fit(model, dm)
 
     tensor_params = deepcopy(same_params)
@@ -325,7 +322,7 @@ def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
-    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+    with pytest.raises(RuntimeError, match="Error while merging hparams"):
         trainer.fit(model, dm)
 
 
```
```diff
@@ -36,7 +36,6 @@ from lightning.pytorch.core.saving import load_hparams_from_yaml, save_hparams_t
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities import _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf
 
 if _OMEGACONF_AVAILABLE:
@@ -897,12 +896,10 @@ def test_no_datamodule_for_hparams(tmpdir):
 
 
 def test_colliding_hparams(tmpdir):
 
     model = SaveHparamsModel({"data_dir": "abc", "arg2": "abc"})
     data = DataModuleWithHparams({"data_dir": "foo"})
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=CSVLogger(tmpdir))
-    with pytest.raises(MisconfigurationException, match=r"Error while merging hparams:"):
+    with pytest.raises(RuntimeError, match=r"Error while merging hparams:"):
         trainer.fit(model, datamodule=data)
 
 
```
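One migration note the test changes imply: `MisconfigurationException` subclasses `Exception`, not `RuntimeError`, so downstream code that caught or asserted the old exception around `fit()` will no longer match the collision error. A hedged downstream sketch, modelled on `test_colliding_hparams` above (`CollidingModel`, `CollidingDataModule`, and the test name are my own):

```python
import pytest
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import CSVLogger


class CollidingModel(BoringModel):
    def __init__(self) -> None:
        super().__init__()
        self.save_hyperparameters({"data_dir": "abc"})


class CollidingDataModule(BoringDataModule):
    def __init__(self) -> None:
        super().__init__()
        self.save_hyperparameters({"data_dir": "foo"})  # same key, different value -> collision


def test_hparams_collision_is_runtime_error(tmp_path):
    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, logger=CSVLogger(tmp_path))
    with pytest.raises(RuntimeError, match="Error while merging hparams"):
        trainer.fit(CollidingModel(), datamodule=CollidingDataModule())
```

If both exception types need to be tolerated during a version transition, catching `(RuntimeError, MisconfigurationException)` and re-raising anything whose message does not mention "Error while merging hparams" is a reasonable stopgap.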