From 49caddde6e6757eb9e772c356c9ae6a1d9261756 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 21 Nov 2023 15:44:18 +0100
Subject: [PATCH] Call `configure_model()` in `LM.load_from_checkpoint()` (#19036)

---
 src/lightning/pytorch/CHANGELOG.md       |  3 +++
 src/lightning/pytorch/core/hooks.py      |  3 ++-
 src/lightning/pytorch/core/module.py     |  7 +++++++
 src/lightning/pytorch/core/saving.py     |  4 ++++
 tests/tests_pytorch/models/test_hooks.py | 20 ++++++++++++++++----
 5 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 3d6e7b1a88..9cefa18f7e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `Trainer.fit()` loop no longer calls `LightningModule.train()` at the start; it now preserves the user's configuration of frozen layers ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))
 
+- The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))
+
+
 ### Deprecated
 
 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))

diff --git a/src/lightning/pytorch/core/hooks.py b/src/lightning/pytorch/core/hooks.py
index 722fb27cbc..2f510fe270 100644
--- a/src/lightning/pytorch/core/hooks.py
+++ b/src/lightning/pytorch/core/hooks.py
@@ -336,7 +336,8 @@ class ModelHooks:
         :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.
 
         This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
-        implementation of this hook is idempotent.
+        implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
+        to it should be a no-op.
 
         """
 
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 14e78702d9..53bccb87e0 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -1519,6 +1519,13 @@ class LightningModule(
             **class** to call it instead of the :class:`LightningModule` instance, or a
             ``TypeError`` will be raised.
 
+        Note:
+            To ensure all layers can be loaded from the checkpoint, this function will call
+            :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
+            model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
+            not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
+            case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.
+
         Example::
 
             # load weights without mapping ...
diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py
index 0dd57bf5bc..56e5afbe46 100644
--- a/src/lightning/pytorch/core/saving.py
+++ b/src/lightning/pytorch/core/saving.py
@@ -38,6 +38,7 @@ from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAA
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.parsing import parse_class_init_keys
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
@@ -157,6 +158,9 @@ def _load_state(
     obj = cls(**_cls_kwargs)
 
     if isinstance(obj, pl.LightningModule):
+        if is_overridden("configure_model", obj):
+            obj.configure_model()
+
         # give model a chance to load something
         obj.on_load_checkpoint(checkpoint)
 
diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py
index 316c5a6cd4..5794f790f7 100644
--- a/tests/tests_pytorch/models/test_hooks.py
+++ b/tests/tests_pytorch/models/test_hooks.py
@@ -877,13 +877,22 @@ def test_trainer_datamodule_hook_system(tmpdir):
     assert called == expected
 
 
-def test_load_from_checkpoint_hook_calls(tmpdir):
+@pytest.mark.parametrize("override_configure_model", [True, False])
+def test_load_from_checkpoint_hook_calls(override_configure_model, tmpdir):
     class CustomHookedDataModule(HookedDataModule):
         def state_dict(self):
             return {"foo": "bar"}
 
+    class CustomHookedModel(HookedModel):
+        pass
+
+    if not override_configure_model:
+        CustomHookedModel.configure_model = None
+
     lm_called, ldm_called = [], []
-    model = HookedModel(lm_called)
+    model = CustomHookedModel(lm_called)
+    assert is_overridden("configure_model", model) == override_configure_model
+
     datamodule = CustomHookedDataModule(ldm_called)
     trainer = Trainer()
     trainer.strategy.connect(model)
@@ -908,9 +917,12 @@ def test_load_from_checkpoint_hook_calls(tmpdir):
     assert ldm_called == [{"name": "state_dict"}]
 
     lm_called, ldm_called = [], []
-    _ = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    _ = CustomHookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
     _ = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
-    assert lm_called == [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+
+    expected_lm_called = [{"name": "configure_model"}] if override_configure_model else []
+    expected_lm_called += [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+    assert lm_called == expected_lm_called
     assert ldm_called == [{"name": "load_state_dict", "args": (saved_ckpt[datamodule_state_dict_key],)}]
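
For context, a minimal sketch (not part of the patch) of the behavior this change enables: a LightningModule that creates its layers in `configure_model()` can now be restored directly with `load_from_checkpoint()`, because the hook runs right after the module is instantiated and before the checkpointed `state_dict` is loaded. The `DemoModel` class, its layer sizes, and the checkpoint path below are illustrative assumptions, not taken from the Lightning codebase.

```python
import torch
from torch import nn

import lightning.pytorch as pl


class DemoModel(pl.LightningModule):
    """Illustrative module that defers layer creation to configure_model()."""

    def __init__(self, hidden: int = 32):
        super().__init__()
        self.save_hyperparameters()
        self.layer = None  # created lazily in configure_model()

    def configure_model(self) -> None:
        # Keep this idempotent: only build the layer the first time the hook runs.
        if self.layer is None:
            self.layer = nn.Linear(self.hparams.hidden, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Before this patch, the call below failed with state_dict key errors because
# `self.layer` did not exist yet when the checkpoint was loaded; the hook is now
# invoked first, so the saved weights have parameters to load into.
# model = DemoModel.load_from_checkpoint("path/to/checkpoint.ckpt")  # hypothetical path
```

As the new docstring note points out, this path still loads a full (non-sharded) checkpoint into memory; very large models should instead be restored through the Trainer via `.fit(ckpt_path=...)`.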