Call `configure_model()` in `LM.load_from_checkpoint()` (#19036)

This commit is contained in:
Adrian Wälchli 2023-11-21 15:44:18 +01:00 committed by GitHub
parent aebac09397
commit 49caddde6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 5 deletions

View File

@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `Trainer.fit()` loop no longer calls `LightningModule.train()` at the start; it now preserves the user's configuration of frozen layers ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))
- The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))
### Deprecated
- Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))

View File

@ -336,7 +336,8 @@ class ModelHooks:
:meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.
This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
implementation of this hook is idempotent.
implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
to it should be a no-op.
"""

View File

@ -1519,6 +1519,13 @@ class LightningModule(
**class** to call it instead of the :class:`LightningModule` instance, or a
``TypeError`` will be raised.
Note:
To ensure all layers can be loaded from the checkpoint, this function will call
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.
Example::
# load weights without mapping ...

View File

@ -38,6 +38,7 @@ from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAA
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.migration import pl_legacy_patch
from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.parsing import parse_class_init_keys
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
@ -157,6 +158,9 @@ def _load_state(
obj = cls(**_cls_kwargs)
if isinstance(obj, pl.LightningModule):
if is_overridden("configure_model", obj):
obj.configure_model()
# give model a chance to load something
obj.on_load_checkpoint(checkpoint)

View File

@ -877,13 +877,22 @@ def test_trainer_datamodule_hook_system(tmpdir):
assert called == expected
def test_load_from_checkpoint_hook_calls(tmpdir):
@pytest.mark.parametrize("override_configure_model", [True, False])
def test_load_from_checkpoint_hook_calls(override_configure_model, tmpdir):
class CustomHookedDataModule(HookedDataModule):
def state_dict(self):
return {"foo": "bar"}
class CustomHookedModel(HookedModel):
pass
if not override_configure_model:
CustomHookedModel.configure_model = None
lm_called, ldm_called = [], []
model = HookedModel(lm_called)
model = CustomHookedModel(lm_called)
assert is_overridden("configure_model", model) == override_configure_model
datamodule = CustomHookedDataModule(ldm_called)
trainer = Trainer()
trainer.strategy.connect(model)
@ -908,9 +917,12 @@ def test_load_from_checkpoint_hook_calls(tmpdir):
assert ldm_called == [{"name": "state_dict"}]
lm_called, ldm_called = [], []
_ = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
_ = CustomHookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
_ = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
assert lm_called == [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
expected_lm_called = [{"name": "configure_model"}] if override_configure_model else []
expected_lm_called += [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
assert lm_called == expected_lm_called
assert ldm_called == [{"name": "load_state_dict", "args": (saved_ckpt[datamodule_state_dict_key],)}]