diff --git a/CHANGELOG.md b/CHANGELOG.md
index 12e49477a7..b74145a356 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -290,6 +290,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
 
+- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 3302d7643d..2367000747 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -57,7 +57,12 @@ class Callback:
         return f"{self.__class__.__qualname__}{repr(kwargs)}"
 
     def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """Called before configure sharded model."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
+
+        Called before configure sharded model.
+        """
 
     def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before accelerator is being setup."""
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index a6467f5452..a71e42f8c9 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_init_start_end(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
    _check_on_hpc_hooks(model)
+    # TODO: Remove this in v1.8
+    _check_on_configure_sharded_model(trainer)
 
 
 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -322,3 +324,12 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
             "Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
             " will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead."
         )
+
+
+def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
+    for callback in trainer.callbacks:
+        if is_overridden(method_name="on_configure_sharded_model", instance=callback):
+            rank_zero_deprecation(
+                "The `on_configure_sharded_model` callback hook was deprecated in"
+                " v1.6 and will be removed in v1.8. Use `setup()` instead."
+            )
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index fa0d982478..55659aeec5 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -351,3 +351,24 @@ def test_v1_8_0_deprecated_lightning_optimizers():
         match="Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8"
     ):
         assert trainer.lightning_optimizers == {}
+
+
+def test_v1_8_0_on_configure_sharded_model(tmpdir):
+    class TestCallback(Callback):
+        def on_configure_sharded_model(self, trainer, model):
+            print("Configuring sharded model")
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=[TestCallback()],
+        max_epochs=1,
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        logger=False,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
+    ):
+        trainer.fit(model)