diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index d0296b2618..63f019dfe8 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -330,3 +330,12 @@ class DDPPlugin(ParallelPlugin): def post_training_step(self): if not self.lightning_module.automatic_optimization: self.model.require_backward_grad_sync = True + + @classmethod + def register_plugins(cls, plugin_registry: Dict) -> None: + plugin_registry.register( + "ddp_find_unused_parameters_false", + cls, + description="DDP Plugin with `find_unused_parameters` as False", + find_unused_parameters=False + ) diff --git a/tests/plugins/test_plugins_registry.py b/tests/plugins/test_plugins_registry.py index c0e112bb6f..8ccba40013 100644 --- a/tests/plugins/test_plugins_registry.py +++ b/tests/plugins/test_plugins_registry.py @@ -14,8 +14,7 @@ import pytest from pytorch_lightning import Trainer -from pytorch_lightning.plugins import TrainingTypePluginsRegistry -from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin +from pytorch_lightning.plugins import DDPPlugin, DeepSpeedPlugin, TrainingTypePluginsRegistry from tests.helpers.runif import RunIf @@ -75,7 +74,7 @@ def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init @RunIf(deepspeed=True) @pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"]) -def test_training_type_plugins_registry_with_trainer(tmpdir, plugin): +def test_deepspeed_training_type_plugins_registry_with_trainer(tmpdir, plugin): trainer = Trainer( default_root_dir=tmpdir, @@ -84,3 +83,13 @@ def test_training_type_plugins_registry_with_trainer(tmpdir, plugin): ) assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin) + + +def test_ddp_training_type_plugins_registry_with_trainer(tmpdir): + + trainer = Trainer( + default_root_dir=tmpdir, + plugins="ddp_find_unused_parameters_false", + ) + + assert isinstance(trainer.training_type_plugin, DDPPlugin)