Deprecate `on_configure_sharded_model` callback hook for v1.6 (#11627)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
6586dd23b7
commit
72f0e5bfae
|
@ -290,6 +290,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
|
||||
|
||||
|
||||
- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
|
||||
|
|
|
@ -57,7 +57,12 @@ class Callback:
|
|||
return f"{self.__class__.__qualname__}{repr(kwargs)}"
|
||||
|
||||
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called before configure sharded model."""
|
||||
r"""
|
||||
.. deprecated:: v1.6
|
||||
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
|
||||
|
||||
Called before configure sharded model.
|
||||
"""
|
||||
|
||||
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
"""Called before accelerator is being setup."""
|
||||
|
|
|
@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
|
|||
_check_on_init_start_end(trainer)
|
||||
# TODO: Delete _check_on_hpc_hooks in v1.8
|
||||
_check_on_hpc_hooks(model)
|
||||
# TODO: Remove this in v1.8
|
||||
_check_on_configure_sharded_model(trainer)
|
||||
|
||||
|
||||
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
|
||||
|
@ -322,3 +324,12 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
|
|||
"Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
|
||||
" will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead."
|
||||
)
|
||||
|
||||
|
||||
def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
|
||||
for callback in trainer.callbacks:
|
||||
if is_overridden(method_name="on_configure_sharded_model", instance=callback):
|
||||
rank_zero_deprecation(
|
||||
"The `on_configure_sharded_model` callback hook was deprecated in"
|
||||
" v1.6 and will be removed in v1.8. Use `setup()` instead."
|
||||
)
|
||||
|
|
|
@ -351,3 +351,24 @@ def test_v1_8_0_deprecated_lightning_optimizers():
|
|||
match="Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8"
|
||||
):
|
||||
assert trainer.lightning_optimizers == {}
|
||||
|
||||
|
||||
def test_v1_8_0_on_configure_sharded_model(tmpdir):
|
||||
class TestCallback(Callback):
|
||||
def on_configure_sharded_model(self, trainer, model):
|
||||
print("Configuring sharded model")
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=[TestCallback()],
|
||||
max_epochs=1,
|
||||
fast_dev_run=True,
|
||||
enable_progress_bar=False,
|
||||
logger=False,
|
||||
default_root_dir=tmpdir,
|
||||
)
|
||||
with pytest.deprecated_call(
|
||||
match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
|
||||
):
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue