Call `configure_model()` in `LM.load_from_checkpoint()` (#19036)
commit 49caddde6e
parent aebac09397

@@ -29,6 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The `Trainer.fit()` loop no longer calls `LightningModule.train()` at the start; it now preserves the user's configuration of frozen layers ([#18951](https://github.com/Lightning-AI/lightning/pull/18951))
 
+- The `LightningModule.load_from_checkpoint()` function now calls `.configure_model()` on the model if it is overridden, to ensure all layers can be loaded from the checkpoint ([#19036](https://github.com/Lightning-AI/lightning/pull/19036))
+
+
 ### Deprecated
 
 - Deprecated all precision plugin classes under `lightning.pytorch.plugins` with the suffix `Plugin` in the name ([#18840](https://github.com/Lightning-AI/lightning/pull/18840))
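
For context, a minimal sketch (not part of the commit) of the behavior the new changelog entry describes. `DemoModel` and the checkpoint path are made up for illustration; the point is that layers created in `configure_model()` now exist before the checkpoint's state dict is restored.

    import torch
    import lightning.pytorch as pl


    class DemoModel(pl.LightningModule):  # hypothetical model for illustration
        def __init__(self):
            super().__init__()
            self.layer = None  # built lazily in configure_model()

        def configure_model(self):
            if self.layer is None:  # guard keeps the hook idempotent
                self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    # Previously, load_from_checkpoint() instantiated the model while `self.layer`
    # was still missing, so the "layer.*" weights could not be restored; with this
    # change, configure_model() is called first.
    # model = DemoModel.load_from_checkpoint("demo.ckpt")  # placeholder path
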
@@ -336,7 +336,8 @@ class ModelHooks:
         :meth:`~lightning.pytorch.trainer.trainer.Trainer.init_module` context manager.
 
         This hook is called during each of fit/val/test/predict stages in the same process, so ensure that
-        implementation of this hook is idempotent.
+        implementation of this hook is **idempotent**, i.e., after the first time the hook is called, subsequent calls
+        to it should be a no-op.
         """
 
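
The reworded docstring above asks for an idempotent hook. A small follow-up to the sketch after the changelog hunk (again not part of the commit), assuming that hypothetical `DemoModel`: calling the hook twice must not rebuild the layers.

    m = DemoModel()          # hypothetical model from the sketch above
    m.configure_model()
    first = m.layer
    m.configure_model()      # e.g. invoked again for another stage in the same process
    assert m.layer is first  # same module object; the second call was a no-op
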
@@ -1519,6 +1519,13 @@ class LightningModule(
             **class** to call it instead of the :class:`LightningModule` instance, or a
             ``TypeError`` will be raised.
 
+        Note:
+            To ensure all layers can be loaded from the checkpoint, this function will call
+            :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` directly after instantiating the
+            model if this hook is overridden in your LightningModule. However, note that ``load_from_checkpoint`` does
+            not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this
+            case, consider loading through the Trainer via ``.fit(ckpt_path=...)``.
+
         Example::
 
             # load weights without mapping ...
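
A rough sketch of the alternative the new note points to (not part of the commit): very large or sharded checkpoints are better restored through the Trainer than through `load_from_checkpoint`. Everything below is a placeholder, including the `last.ckpt` path, and it reuses the hypothetical `DemoModel` and imports sketched earlier.

    from torch.utils.data import DataLoader, TensorDataset

    model = DemoModel()
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    # The Trainer sets the model up through its strategy before restoring state,
    # which is the route the docstring suggests for sharded or very large checkpoints.
    trainer.fit(model, train_loader, ckpt_path="last.ckpt")  # "last.ckpt" is a placeholder
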
@@ -38,6 +38,7 @@ from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAA
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_helpers import is_overridden
 from lightning.pytorch.utilities.parsing import parse_class_init_keys
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
@@ -157,6 +158,9 @@ def _load_state(
     obj = cls(**_cls_kwargs)
 
+    if isinstance(obj, pl.LightningModule):
+        if is_overridden("configure_model", obj):
+            obj.configure_model()
+
     # give model a chance to load something
     obj.on_load_checkpoint(checkpoint)
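
For intuition about the new gate in `_load_state` (not part of the commit): `is_overridden` checks whether the instance's class actually redefines the hook, so plain modules keep the previous behavior. `PlainModel` and `ConfiguredModel` are made-up names; `is_overridden` and its call signature are taken from the diff itself.

    import lightning.pytorch as pl
    from lightning.pytorch.utilities.model_helpers import is_overridden


    class PlainModel(pl.LightningModule):  # does not override the hook
        pass


    class ConfiguredModel(pl.LightningModule):  # overrides the hook
        def configure_model(self):
            pass


    assert not is_overridden("configure_model", PlainModel())   # hook is not called on load
    assert is_overridden("configure_model", ConfiguredModel())  # hook runs before the weights are loaded
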
@@ -877,13 +877,22 @@ def test_trainer_datamodule_hook_system(tmpdir):
     assert called == expected
 
 
-def test_load_from_checkpoint_hook_calls(tmpdir):
+@pytest.mark.parametrize("override_configure_model", [True, False])
+def test_load_from_checkpoint_hook_calls(override_configure_model, tmpdir):
     class CustomHookedDataModule(HookedDataModule):
         def state_dict(self):
             return {"foo": "bar"}
 
+    class CustomHookedModel(HookedModel):
+        pass
+
+    if not override_configure_model:
+        CustomHookedModel.configure_model = None
+
     lm_called, ldm_called = [], []
-    model = HookedModel(lm_called)
+    model = CustomHookedModel(lm_called)
+    assert is_overridden("configure_model", model) == override_configure_model
+
     datamodule = CustomHookedDataModule(ldm_called)
     trainer = Trainer()
     trainer.strategy.connect(model)
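
A side note on the parametrization trick above (not stated explicitly in the commit): assigning `None` over the subclass's `configure_model` attribute makes `is_overridden` report the hook as not overridden, which is how the `False` case is produced without defining a second model class. A tiny standalone check with a made-up `WithHook` class:

    import lightning.pytorch as pl
    from lightning.pytorch.utilities.model_helpers import is_overridden


    class WithHook(pl.LightningModule):  # illustrative only
        def configure_model(self):
            pass


    assert is_overridden("configure_model", WithHook())
    WithHook.configure_model = None  # mirrors `CustomHookedModel.configure_model = None`
    assert not is_overridden("configure_model", WithHook())
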
@@ -908,9 +917,12 @@ def test_load_from_checkpoint_hook_calls(tmpdir):
     assert ldm_called == [{"name": "state_dict"}]
 
     lm_called, ldm_called = [], []
-    _ = HookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
+    _ = CustomHookedModel.load_from_checkpoint(ckpt_path, called=lm_called)
     _ = CustomHookedDataModule.load_from_checkpoint(ckpt_path, called=ldm_called)
-    assert lm_called == [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+
+    expected_lm_called = [{"name": "configure_model"}] if override_configure_model else []
+    expected_lm_called += [{"name": "on_load_checkpoint", "args": ({**saved_ckpt, "hyper_parameters": ANY},)}]
+    assert lm_called == expected_lm_called
     assert ldm_called == [{"name": "load_state_dict", "args": (saved_ckpt[datamodule_state_dict_key],)}]