Remove deprecated on_load/save_checkpoint behavior (#14835)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli 2022-10-10 07:08:13 -04:00 committed by GitHub
parent 0b04aa879f
commit 8f90084059
10 changed files with 87 additions and 188 deletions

View File

@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911))
+- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
 ### Deprecated
 - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
@@ -302,6 +305,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the deprecated `LightningDataModule.on_save/load_checkpoint` hooks ([#14909](https://github.com/Lightning-AI/lightning/pull/14909))
+- Removed support for returning a value in `Callback.on_save_checkpoint` in favor of implementing `Callback.state_dict` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
 ### Fixed
 - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822))

View File

@@ -277,7 +277,7 @@ class Callback:
     def on_save_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    ) -> None:
         r"""
         Called when saving a checkpoint to give you a chance to store anything else you might want to save.
@@ -285,18 +285,10 @@ class Callback:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
             checkpoint: the checkpoint dictionary that will be saved.
-        Returns:
-            None or the callback state. Support for returning callback state will be removed in v1.8.
-        .. deprecated:: v1.6
-            Returning a value from this method was deprecated in v1.6 and will be removed in v1.8.
-            Implement ``Callback.state_dict`` instead to return state.
-            In v1.8 ``Callback.on_save_checkpoint`` can only return None.
         """
     def on_load_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
     ) -> None:
         r"""
         Called when loading a model checkpoint, use to reload state.
@@ -304,18 +296,7 @@ class Callback:
         Args:
             trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
             pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance.
-            callback_state: the callback state returned by ``on_save_checkpoint``.
-        Note:
-            The ``on_load_checkpoint`` won't be called with an undefined state.
-            If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
-            you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
-        .. deprecated:: v1.6
-            This callback hook will change its signature and behavior in v1.8.
-            If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead.
-            In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded
-            checkpoint dictionary instead of only the callback state from the checkpoint.
+            checkpoint: the full checkpoint dictionary that got loaded by the Trainer.
         """
     def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None:
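
For callbacks that stored state through the removed return-value path, the replacement is the stateful API these docstrings now point to. A minimal migration sketch, assuming a user-defined callback (the `CounterCallback` name and its `count` attribute are illustrative, not part of this commit):

```python
from pytorch_lightning import Callback


class CounterCallback(Callback):
    """Hypothetical callback that persists a counter across checkpoint save/load."""

    def __init__(self):
        self.count = 0

    def state_dict(self):
        # Replaces returning a dict from `on_save_checkpoint` (no longer supported as of v1.8).
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        # Replaces reading `callback_state` in the old `on_load_checkpoint` signature.
        self.count = state_dict["count"]

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # As of v1.8 this hook receives the full checkpoint dictionary.
        print("checkpoint keys:", list(checkpoint))
```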

View File

@@ -423,9 +423,7 @@ class ModelPruning(Callback):
         return apply_to_collection(state_dict, Tensor, move_to_cpu)
-    def on_save_checkpoint(
-        self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]
-    ) -> Optional[dict]:
+    def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")
             # manually prune the weights so training can keep going with the same buffers

View File

@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import pytorch_lightning as pl
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
@@ -220,12 +222,16 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
                 "The `on_before_accelerator_backend_setup` callback hook was deprecated in"
                 " v1.6 and will be removed in v1.8. Use `setup()` instead."
             )
-        if is_overridden(method_name="on_load_checkpoint", instance=callback):
-            rank_zero_deprecation(
-                f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8."
+        has_legacy_argument = "callback_state" in inspect.signature(callback.on_load_checkpoint).parameters
+        if is_overridden(method_name="on_load_checkpoint", instance=callback) and has_legacy_argument:
+            # TODO: Remove this error message in v2.0
+            raise RuntimeError(
+                f"`{callback.__class__.__name__}.on_load_checkpoint` has changed its signature and behavior in v1.8."
                 " If you wish to load the state of the callback, use `load_state_dict` instead."
-                " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded"
-                " checkpoint dictionary instead of callback state."
+                " As of 1.8, `on_load_checkpoint(..., checkpoint)` receives the entire loaded"
+                " checkpoint dictionary instead of the callback state. To continue using this hook and avoid this error"
+                " message, rename the `callback_state` argument to `checkpoint`."
             )
     for hook, alternative_hook in (
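
The `inspect.signature` check above only trips for overridden hooks that still name the third parameter `callback_state`; renaming it to `checkpoint` is enough to pass validation. A small sketch of that detection, with both callback classes being illustrative:

```python
import inspect

from pytorch_lightning import Callback


class LegacyCallback(Callback):
    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        ...


class UpdatedCallback(Callback):
    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        ...


# True -> the validator raises the RuntimeError above
print("callback_state" in inspect.signature(LegacyCallback.on_load_checkpoint).parameters)
# False -> the renamed hook passes validation
print("callback_state" in inspect.signature(UpdatedCallback.on_load_checkpoint).parameters)
```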

View File

@@ -1354,11 +1354,7 @@ class Trainer:
         return callback_state_dicts
     def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.
-        Will be removed in v1.8: If state is returned, we insert the callback state into
-        ``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present.
-        """
+        """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook."""
         pl_module = self.lightning_module
         if pl_module:
             prev_fx_name = pl_module._current_fx_name
@@ -1367,13 +1363,13 @@
         for callback in self.callbacks:
             with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"):
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
-            if state:
-                rank_zero_deprecation(
-                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` is deprecated in v1.6"
-                    " and will be removed in v1.8. Please override `Callback.state_dict`"
-                    " to return state to be saved."
+            if state is not None:
+                # TODO: Remove this error message in v2.0
+                raise ValueError(
+                    f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` was deprecated in v1.6"
+                    f" and is no longer supported as of v1.8. Please override `Callback.state_dict` to return state"
+                    f" to be saved."
                 )
-                checkpoint["callbacks"][callback.state_key] = state
         if pl_module:
             # restore current_fx when nested context
@@ -1406,11 +1402,8 @@
             )
         for callback in self.callbacks:
-            state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
-            if state:
-                state = deepcopy(state)
-                with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
-                    callback.on_load_checkpoint(self, self.lightning_module, state)
+            with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"):
+                callback.on_load_checkpoint(self, self.lightning_module, checkpoint)
         if pl_module:
             # restore current_fx when nested context
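
With the return-value path gone, callback state reaches the checkpoint only through `Callback.state_dict`, stored under the callback's `state_key` in `checkpoint["callbacks"]`, and `on_load_checkpoint` now sees the whole checkpoint rather than that per-callback entry. A sketch of inspecting a saved checkpoint (the file name is hypothetical):

```python
import torch

ckpt = torch.load("example.ckpt")  # hypothetical file written by `Trainer.save_checkpoint`
# Maps each callback's `state_key` to whatever its `state_dict()` returned:
for state_key, state in ckpt.get("callbacks", {}).items():
    print(state_key, state)
```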

View File

@@ -133,44 +133,6 @@ def test_resume_callback_state_saved_by_type_stateful(tmpdir):
     assert callback.state == 111
-class OldStatefulCallbackHooks(Callback):
-    def __init__(self, state):
-        self.state = state
-    @property
-    def state_key(self):
-        return type(self)
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        return {"state": self.state}
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state = callback_state["state"]
-def test_resume_callback_state_saved_by_type_hooks(tmpdir):
-    """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using deprecated
-    on_save/load_checkpoint signatures."""
-    # TODO: remove old on_save/load_checkpoint signature support in v1.8
-    # in favor of Stateful and new on_save/load_checkpoint signatures
-    # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
-    # will become
-    # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
-    model = BoringModel()
-    callback = OldStatefulCallbackHooks(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-    callback = OldStatefulCallbackHooks(state=222)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model, ckpt_path=ckpt_path)
-    assert callback.state == 111
 def test_resume_incomplete_callbacks_list_warning(tmpdir):
     model = BoringModel()
     callback0 = ModelCheckpoint(monitor="epoch")
@@ -198,49 +160,3 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir):
     )
     with no_warning_call(UserWarning, match="Please add the following callbacks:"):
         trainer.fit(model, ckpt_path=ckpt_path)
-class AllStatefulCallback(Callback):
-    def __init__(self, state):
-        self.state = state
-    @property
-    def state_key(self):
-        return type(self)
-    def state_dict(self):
-        return {"new_state": self.state}
-    def load_state_dict(self, state_dict):
-        assert state_dict == {"old_state_precedence": 10}
-        self.state = state_dict["old_state_precedence"]
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-        return {"old_state_precedence": 10}
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        assert callback_state == {"old_state_precedence": 10}
-        self.old_state_precedence = callback_state["old_state_precedence"]
-def test_resume_callback_state_all(tmpdir):
-    """Test on_save/load_checkpoint state precedence over state_dict/load_state_dict until v1.8 removal."""
-    # TODO: remove old on_save/load_checkpoint signature support in v1.8
-    # in favor of Stateful and new on_save/load_checkpoint signatures
-    # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state)
-    # will become
-    # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint)
-    model = BoringModel()
-    callback = AllStatefulCallback(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-    callback = AllStatefulCallback(state=222)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback])
-    with pytest.deprecated_call():
-        trainer.fit(model, ckpt_path=ckpt_path)
-    assert callback.state == 10
-    assert callback.old_state_precedence == 10

View File

@@ -36,7 +36,6 @@ def test_lambda_call(tmpdir):
     hooks = get_members(Callback) - {"state_dict", "load_state_dict"}
     hooks_args = {h: partial(call, h) for h in hooks}
-    hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]
     model = CustomModel()

View File

@@ -200,56 +200,3 @@ def test_deprecated_mc_save_checkpoint():
         match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6"
     ):
         mc.save_checkpoint(trainer)
-def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir):
-    class TestCallbackLoadHook(Callback):
-        def on_load_checkpoint(self, trainer, pl_module, callback_state):
-            print("overriding on_load_checkpoint")
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackLoadHook()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    with pytest.deprecated_call(
-        match="`TestCallbackLoadHook.on_load_checkpoint` will change its signature and behavior in v1.8."
-        " If you wish to load the state of the callback, use `load_state_dict` instead."
-        r" In v1.8 `on_load_checkpoint\(..., checkpoint\)` will receive the entire loaded"
-        " checkpoint dictionary instead of callback state."
-    ):
-        trainer.fit(model)
-def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir):
-    class TestCallbackSaveHookReturn(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            return {"returning": "on_save_checkpoint"}
-    class TestCallbackSaveHookOverride(Callback):
-        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
-            print("overriding without returning")
-    model = BoringModel()
-    trainer = Trainer(
-        callbacks=[TestCallbackSaveHookReturn()],
-        max_epochs=1,
-        fast_dev_run=True,
-        enable_progress_bar=False,
-        logger=False,
-        default_root_dir=tmpdir,
-    )
-    trainer.fit(model)
-    with pytest.deprecated_call(
-        match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6"
-        " and will be removed in v1.8. Please override `Callback.state_dict`"
-        " to return state to be saved."
-    ):
-        trainer.save_checkpoint(tmpdir + "/path.ckpt")
-    trainer.callbacks = [TestCallbackSaveHookOverride()]
-    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")

View File

@@ -17,7 +17,7 @@ from unittest import mock
 import pytest
 import pytorch_lightning
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.demos.boring_classes import BoringModel
 from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback
 from tests_pytorch.helpers.runif import RunIf
@@ -46,7 +46,7 @@ def test_v2_0_0_deprecated_ipus(_, monkeypatch):
         _ = Trainer(ipus=4)
-def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
+def test_v2_0_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     # test resume_from_checkpoint still works until v2.0 deprecation
     model = BoringModel()
     callback = OldStatefulCallback(state=111)
@@ -84,3 +84,55 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     trainer = Trainer(resume_from_checkpoint="trainer_arg_path")
     with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."):
         trainer.fit(model, ckpt_path="fit_arg_ckpt_path")
+def test_v2_0_0_callback_on_load_checkpoint_hook(tmpdir):
+    class TestCallbackLoadHook(Callback):
+        def on_load_checkpoint(self, trainer, pl_module, callback_state):
+            print("overriding on_load_checkpoint")
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackLoadHook()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.raises(
+        RuntimeError, match="`TestCallbackLoadHook.on_load_checkpoint` has changed its signature and behavior in v1.8."
+    ):
+        trainer.fit(model)
+def test_v2_0_0_callback_on_save_checkpoint_hook(tmpdir):
+    class TestCallbackSaveHookReturn(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            return {"returning": "on_save_checkpoint"}
+    class TestCallbackSaveHookOverride(Callback):
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            print("overriding without returning")
+    model = BoringModel()
+    trainer = Trainer(
+        callbacks=[TestCallbackSaveHookReturn()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    trainer.fit(model)
+    with pytest.raises(
+        ValueError,
+        match=(
+            "Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` was deprecated in v1.6 and is"
+            " no longer supported as of v1.8"
+        ),
+    ):
+        trainer.save_checkpoint(tmpdir + "/path.ckpt")
+    trainer.callbacks = [TestCallbackSaveHookOverride()]
+    trainer.save_checkpoint(tmpdir + "/pathok.ckpt")

View File

@@ -637,7 +637,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
-        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
+        dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)),
         dict(name="Callback.load_state_dict", args=({"foo": True},)),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
@@ -726,7 +726,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
         dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
         dict(name="setup", kwargs=dict(stage="fit")),
         dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
-        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
+        dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)),
        dict(name="Callback.load_state_dict", args=({"foo": True},)),
         dict(name="configure_sharded_model"),
         dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),