Deprecate `on_configure_sharded_model` callback hook for v1.6 (#11627)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Piyush Hirapara 2022-02-03 07:59:26 +05:30 committed by GitHub
parent 6586dd23b7
commit 72f0e5bfae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 1 deletion

View File

@ -290,6 +290,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
### Removed
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

View File

@ -57,7 +57,12 @@ class Callback:
return f"{self.__class__.__qualname__}{repr(kwargs)}"
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    r"""Called before configure sharded model.

    .. deprecated:: v1.6
        This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
    """
    # No-op base implementation: subclasses override to customize sharded-model setup.
    # The diff residue previously left two stacked string literals here; only the first
    # became ``__doc__``, so the deprecation directive was lost — merged into one docstring.
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Hook invoked just before the accelerator backend is set up; base implementation does nothing."""

View File

@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_on_init_start_end(trainer)
# TODO: Delete _check_on_hpc_hooks in v1.8
_check_on_hpc_hooks(model)
# TODO: Remove this in v1.8
_check_on_configure_sharded_model(trainer)
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@ -322,3 +324,12 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
"Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
" will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead."
)
def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
    """Warn once per callback that still overrides the deprecated ``on_configure_sharded_model`` hook.

    Iterates the trainer's registered callbacks and emits a deprecation warning for every
    callback instance overriding the hook; scheduled for removal in v1.8 along with the hook.
    """
    overriding = (
        cb for cb in trainer.callbacks if is_overridden(method_name="on_configure_sharded_model", instance=cb)
    )
    for _ in overriding:
        rank_zero_deprecation(
            "The `on_configure_sharded_model` callback hook was deprecated in"
            " v1.6 and will be removed in v1.8. Use `setup()` instead."
        )

View File

@ -351,3 +351,24 @@ def test_v1_8_0_deprecated_lightning_optimizers():
match="Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8"
):
assert trainer.lightning_optimizers == {}
def test_v1_8_0_on_configure_sharded_model(tmpdir):
    """Overriding the ``on_configure_sharded_model`` callback hook must raise a deprecation warning on fit."""

    class DeprecatedHookCallback(Callback):
        # Deliberately overrides the deprecated hook so the loop-configuration
        # check fires during ``trainer.fit``.
        def on_configure_sharded_model(self, trainer, model):
            print("Configuring sharded model")

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[DeprecatedHookCallback()],
        fast_dev_run=True,
        max_epochs=1,
        enable_progress_bar=False,
        logger=False,
    )
    with pytest.deprecated_call(
        match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
    ):
        trainer.fit(BoringModel())