Remove deprecated on_load/save_checkpoint behavior (#14835)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 0b04aa879f, commit 8f90084059
@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911))
 
+- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
+
 ### Deprecated
 
 - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
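
For illustration, the new hook contract from a user callback's point of view (a sketch, not part of this diff; `ResumeLogger` is a hypothetical name, and the `epoch`/`global_step` keys are assumed top-level entries of a Lightning checkpoint):

    from pytorch_lightning import Callback

    class ResumeLogger(Callback):
        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # `checkpoint` is now the entire loaded dictionary, not just this
            # callback's saved state, so global metadata is visible here.
            print(f"resuming from epoch {checkpoint['epoch']}, step {checkpoint['global_step']}")
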
@@ -302,6 +305,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the deprecated `LightningDataModule.on_save/load_checkpoint` hooks ([#14909](https://github.com/Lightning-AI/lightning/pull/14909))
 
+- Removed support for returning a value in `Callback.on_save_checkpoint` in favor of implementing `Callback.state_dict` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
+
 ### Fixed
 
 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))
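
A possible migration for callbacks that relied on the removed return value (a sketch; `CounterCallback` is illustrative). Per the Trainer docstring removed below, returned state used to land in ``checkpoint["callbacks"][Callback.state_key]``; `state_dict`/`load_state_dict` now cover that path:

    from pytorch_lightning import Callback

    class CounterCallback(Callback):
        def __init__(self):
            self.count = 0

        # Previously: `return {"count": self.count}` from on_save_checkpoint.
        # Now the Trainer persists whatever state_dict() returns and feeds it
        # back through load_state_dict() on resume.
        def state_dict(self):
            return {"count": self.count}

        def load_state_dict(self, state_dict):
            self.count = state_dict["count"]
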
@@ -277,7 +277,7 @@ class Callback:
 
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    ) -> None:
         r"""
         Called when saving a checkpoint to give you a chance to store anything else you might want to save.
 
@@ -285,18 +285,10 @@ class Callback:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
             checkpoint: the checkpoint dictionary that will be saved.
-
-        Returns:
-            None or the callback state. Support for returning callback state will be removed in v1.8.
-
-        .. deprecated:: v1.6
-            Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
-            Implement ``Callback.state_dict`` instead to return state.
-            In v1.8 ``Callback.on_save_checkpoint`` can only return None.
         """
 
     def on_load_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> None:
         r"""
         Called when loading a model checkpoint, use to reload state.
 
@@ -304,18 +296,7 @@ class Callback:
         Args:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
-            callback_state: the callback state returned by ``on_save_checkpoint``.
-
-        Note:
-            The ``on_load_checkpoint`` won't be called with an undefined state.
-            If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
-            you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
-
-        .. deprecated:: v1.6
-            This callback hook will change its signature and behavior in v1.8.
-            If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
-            In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
-            checkpoint dictionary instead of only the callback state from the checkpoint.
+            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
         """
 
     def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
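
Note that both hooks still receive the checkpoint dictionary and may mutate it in place; only the return-value path is gone. A sketch (`CheckpointMetadata` and the extra key are illustrative):

    from pytorch_lightning import Callback

    class CheckpointMetadata(Callback):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Writing into the dict is still supported; returning a value now raises.
            checkpoint["my_metadata"] = {"source": "CheckpointMetadata"}

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # The full dictionary arrives here, so the key written above is readable.
            metadata = checkpoint.get("my_metadata")
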
@@ -423,9 +423,7 @@ class ModelPruning(Callback):
 
         return apply_to_collection(state_dict, Tensor, move_to_cpu)
 
-    def on_save_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None:
         if self._make_pruning_permanent:
            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
            # manually prune the weights so training can keep going with the same buffers
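
Third-party callbacks overriding this hook need the same signature update; a minimal sketch (class name hypothetical):

    from typing import Any, Dict

    import pytorch_lightning as pl
    from pytorch_lightning import Callback, LightningModule

    class MyPruningLikeCallback(Callback):
        # Before: `-> Optional[dict]`, optionally returning state.
        # After: `-> None`; persistent state goes through state_dict()/load_state_dict().
        def on_save_checkpoint(
            self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]
        ) -> None:
            ...
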
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
+
 import pytorch_lightning as pl
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
 
@@ -220,12 +222,16 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
                 "The `on_before_accelerator_backend_setup` callback hook was deprecated in"
                 " v1.6 and will be removed in v1.8. Use `setup()` instead."
             )
-        if is_overridden(method_name="on_load_checkpoint", instance=callback):
-            rank_zero_deprecation(
-                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
-                " If you wish to load the state of the callback, use `load_state_dict` instead."
-                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
-                " checkpoint dictionary instead of callback state."
+        has_legacy_argument = "callback_state" in inspect.signature(callback.on_load_checkpoint).parameters
+        if is_overridden(method_name="on_load_checkpoint", instance=callback) and has_legacy_argument:
+            # TODO: Remove this error message in v2.0
+            raise RuntimeError(
+                f"`{callback.__class__.__name__}.on_load_checkpoint` has changed its signature and behavior in v1.8."
+                " If you wish to load the state of the callback, use `load_state_dict` instead."
+                " As of 1.8, `on_load_checkpoint(..., checkpoint)` receives the entire loaded"
+                " checkpoint dictionary instead of the callback state. To continue using this hook and avoid this error"
+                " message, rename the `callback_state` argument to `checkpoint`."
             )
 
         for hook, alternative_hook in (
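
The legacy-signature detection leans on `inspect.signature` reporting parameter names of a bound method (`self` excluded); a self-contained sketch of the same check, with illustrative class names:

    import inspect

    def has_legacy_argument(hook) -> bool:
        # True when an override still names its third parameter `callback_state`.
        return "callback_state" in inspect.signature(hook).parameters

    class LegacyCallback:
        def on_load_checkpoint(self, trainer, pl_module, callback_state): ...

    class UpdatedCallback:
        def on_load_checkpoint(self, trainer, pl_module, checkpoint): ...

    assert has_legacy_argument(LegacyCallback().on_load_checkpoint)
    assert not has_legacy_argument(UpdatedCallback().on_load_checkpoint)
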
@@ -1354,11 +1354,7 @@ class Trainer:
         return callback_state_dicts
 
     def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.
-
-        Will be removed in v1.8: If state is returned, we insert the callback state into
-        ``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
-        """
+        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook."""
         pl_module = self.lightning_module
         if pl_module:
             prev_fx_name = pl_module._current_fx_name
 
@@ -1367,13 +1363,13 @@ class Trainer:
         for callback in self.callbacks:
             with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
-            if state:
-                rank_zero_deprecation(
-                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` is deprecated in v1.6"
-                    " and will be removed in v1.8. Please override `Callback.state_dict`"
-                    " to return state to be saved."
+            if state is not None:
+                # TODO: Remove this error message in v2.0
+                raise ValueError(
+                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` was deprecated in v1.6"
+                    f" and is no longer supported as of v1.8. Please override `Callback.state_dict` to return state"
+                    f" to be saved."
                 )
-                checkpoint["callbacks"][callback.state_key] = state
 
         if pl_module:
             # restore current_fx when nested context
 
@@ -1406,11 +1402,8 @@ class Trainer:
             )
 
         for callback in self.callbacks:
-            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
-            if state:
-                state = deepcopy(state)
-                with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
-                    callback.on_load_checkpoint(self, self.lightning_module, state)
+            with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
+                callback.on_load_checkpoint(self, self.lightning_module, checkpoint)
 
         if pl_module:
             # restore current_fx when nested context
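
With the per-callback state lookup gone from this hook call, restoration is split: `load_state_dict` gets the callback's own entry (handled elsewhere by the Trainer) while `on_load_checkpoint` observes the whole dictionary. A rough sketch of that split, using a hypothetical helper rather than the Trainer's actual code:

    from copy import deepcopy

    def restore_callbacks(trainer, lightning_module, callbacks, checkpoint):
        # Hypothetical helper mirroring the new division of labor.
        callback_states = checkpoint.get("callbacks", {})
        for cb in callbacks:
            state = callback_states.get(cb.state_key)
            if state is not None:
                cb.load_state_dict(deepcopy(state))  # per-callback state
            cb.on_load_checkpoint(trainer, lightning_module, checkpoint)  # full dict
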
@@ -133,44 +133,6 @@ def test_resume_callback_state_saved_by_type_stateful(tmpdir):
     assert callback.state == 111
 
 
-class OldStatefulCallbackHooks(Callback):
-    def __init__(self, state):
-        self.state = state
-
-    @property
-    def state_key(self):
-        return type(self)
-
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        return {"state": self.state}
-
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state = callback_state["state"]
-
-
-def test_resume_callback_state_saved_by_type_hooks(tmpdir):
-    """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using deprecated
-    on_save/load_checkpoint signatures."""
-    # TODO: remove old on_save/load_checkpoint signature support in v1.8
-    # in favor of Stateful and new on_save/load_checkpoint signatures
-    # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
-    # will become
-    # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
-    model = BoringModel()
-    callback = OldStatefulCallbackHooks(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-
-    callback = OldStatefulCallbackHooks(state=222)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model, ckpt_path=ckpt_path)
-    assert callback.state == 111
-
-
 def test_resume_incomplete_callbacks_list_warning(tmpdir):
     model = BoringModel()
     callback0 = ModelCheckpoint(monitor="epoch")
@@ -198,49 +160,3 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir):
     )
     with no_warning_call(UserWarning, match="Please add the following callbacks:"):
         trainer.fit(model, ckpt_path=ckpt_path)
-
-
-class AllStatefulCallback(Callback):
-    def __init__(self, state):
-        self.state = state
-
-    @property
-    def state_key(self):
-        return type(self)
-
-    def state_dict(self):
-        return {"new_state": self.state}
-
-    def load_state_dict(self, state_dict):
-        assert state_dict == {"old_state_precedence": 10}
-        self.state = state_dict["old_state_precedence"]
-
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        return {"old_state_precedence": 10}
-
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        assert callback_state == {"old_state_precedence": 10}
-        self.old_state_precedence = callback_state["old_state_precedence"]
-
-
-def test_resume_callback_state_all(tmpdir):
-    """Test on_save/load_checkpoint state precedence over state_dict/load_state_dict until v1.8 removal."""
-    # TODO: remove old on_save/load_checkpoint signature support in v1.8
-    # in favor of Stateful and new on_save/load_checkpoint signatures
-    # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
-    # will become
-    # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
-    model = BoringModel()
-    callback = AllStatefulCallback(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-
-    callback = AllStatefulCallback(state=222)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model, ckpt_path=ckpt_path)
-    assert callback.state == 10
-    assert callback.old_state_precedence == 10
@@ -36,7 +36,6 @@ def test_lambda_call(tmpdir):
 
     hooks = get_members(Callback) - {"state_dict", "load_state_dict"}
     hooks_args = {h: partial(call, h) for h in hooks}
-    hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]
 
     model = CustomModel()
@@ -200,56 +200,3 @@ def test_deprecated_mc_save_checkpoint():
         match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
     ):
         mc.save_checkpoint(trainer)
-
-
-def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
-    class TestCallbackLoadHook(Callback):
-        def on_load_checkpoint(self, trainer, pl_module, callback_state):
-            print("overriding on_load_checkpoint")
-
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackLoadHook()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    with pytest.deprecated_call(
-        match="`TestCallbackLoadHook.on_load_checkpoint` will change its signature and behavior in v1.8."
-        " If you wish to load the state of the callback, use `load_state_dict` instead."
-        r" In v1.8 `on_load_checkpoint\(..., checkpoint\)` will receive the entire loaded"
-        " checkpoint dictionary instead of callback state."
-    ):
-        trainer.fit(model)
-
-
-def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir):
-    class TestCallbackSaveHookReturn(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            return {"returning": "on_save_checkpoint"}
-
-    class TestCallbackSaveHookOverride(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            print("overriding without returning")
-
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackSaveHookReturn()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    trainer.fit(model)
-    with pytest.deprecated_call(
-        match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6"
-        " and will be removed in v1.8. Please override `Callback.state_dict`"
-        " to return state to be saved."
-    ):
-        trainer.save_checkpoint(tmpdir + "/path.ckpt")
-
-    trainer.callbacks = [TestCallbackSaveHookOverride()]
-    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")
@@ -17,7 +17,7 @@ from unittest import mock
 import pytest
 
 import pytorch_lightning
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback
 from tests_pytorch.helpers.runif import RunIf
@@ -46,7 +46,7 @@ def test_v2_0_0_deprecated_ipus(_, monkeypatch):
     _ = Trainer(ipus=4)
 
 
-def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
+def test_v2_0_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     # test resume_from_checkpoint still works until v2.0 deprecation
     model = BoringModel()
     callback = OldStatefulCallback(state=111)
@@ -84,3 +84,55 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
     with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
         trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
+
+
+def test_v2_0_0_callback_on_load_checkpoint_hook(tmpdir):
+    class TestCallbackLoadHook(Callback):
+        def on_load_checkpoint(self, trainer, pl_module, callback_state):
+            print("overriding on_load_checkpoint")
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackLoadHook()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.raises(
+        RuntimeError, match="`TestCallbackLoadHook.on_load_checkpoint` has changed its signature and behavior in v1.8."
+    ):
+        trainer.fit(model)
+
+
+def test_v2_0_0_callback_on_save_checkpoint_hook(tmpdir):
+    class TestCallbackSaveHookReturn(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            return {"returning": "on_save_checkpoint"}
+
+    class TestCallbackSaveHookOverride(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            print("overriding without returning")
+
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackSaveHookReturn()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    trainer.fit(model)
+    with pytest.raises(
+        ValueError,
+        match=(
+            "Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` was deprecated in v1.6 and is"
+            " no longer supported as of v1.8"
+        ),
+    ):
+        trainer.save_checkpoint(tmpdir + "/path.ckpt")
+
+    trainer.callbacks = [TestCallbackSaveHookOverride()]
+    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")
@@ -637,7 +637,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
-        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
+        dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)),
         dict(name="Callback.load_state_dict", args=({"foo": True},)),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
@@ -726,7 +726,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
-        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
+        dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)),
         dict(name="Callback.load_state_dict", args=({"foo": True},)),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),