Add ddp_find_unused_parameters_false to Registry (#7224)

Kaushik B 2021-05-05 04:10:00 +05:30 committed by GitHub
parent df579a842a
commit e21b7a62d7
2 changed files with 21 additions and 3 deletions


@@ -330,3 +330,12 @@ class DDPPlugin(ParallelPlugin):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @classmethod
+    def register_plugins(cls, plugin_registry: Dict) -> None:
+        plugin_registry.register(
+            "ddp_find_unused_parameters_false",
+            cls,
+            description="DDP Plugin with `find_unused_parameters` as False",
+            find_unused_parameters=False
+        )
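
For orientation, here is a minimal sketch of the registry mechanism this hook feeds into, assuming `register` stores the class together with its default init kwargs and `get` instantiates on lookup. This is a simplified illustration, not the actual TrainingTypePluginsRegistry source:

from typing import Any, Callable


class _PluginRegistrySketch(dict):
    """Simplified stand-in for TrainingTypePluginsRegistry (illustrative only)."""

    def register(self, name: str, plugin: Callable, description: str = "", **init_params: Any) -> None:
        # Store how to build the plugin later, keyed by its string alias.
        self[name] = {"plugin": plugin, "description": description, "init_params": init_params}

    def get(self, name: str) -> Any:
        # Instantiate with the registered defaults, so looking up
        # "ddp_find_unused_parameters_false" yields
        # DDPPlugin(find_unused_parameters=False).
        data = self[name]
        return data["plugin"](**data["init_params"])

Under this reading, the classmethod above simply records DDPPlugin under the alias "ddp_find_unused_parameters_false" with `find_unused_parameters=False` baked in as a default init parameter.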


@@ -14,8 +14,7 @@
 import pytest

 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import TrainingTypePluginsRegistry
-from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
+from pytorch_lightning.plugins import DDPPlugin, DeepSpeedPlugin, TrainingTypePluginsRegistry
 from tests.helpers.runif import RunIf
@@ -75,7 +74,7 @@ def test_training_type_plugins_registry_with_deepspeed_plugins(plugin_name, init
 @RunIf(deepspeed=True)
 @pytest.mark.parametrize("plugin", ["deepspeed", "deepspeed_stage_2_offload", "deepspeed_stage_3"])
-def test_training_type_plugins_registry_with_trainer(tmpdir, plugin):
+def test_deepspeed_training_type_plugins_registry_with_trainer(tmpdir, plugin):
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -84,3 +83,13 @@ def test_training_type_plugins_registry_with_trainer(tmpdir, plugin):
     )

     assert isinstance(trainer.training_type_plugin, DeepSpeedPlugin)
+
+
+def test_ddp_training_type_plugins_registry_with_trainer(tmpdir):
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        plugins="ddp_find_unused_parameters_false",
+    )
+
+    assert isinstance(trainer.training_type_plugin, DDPPlugin)
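
For end users, the registered alias is shorthand for constructing the plugin by hand. A sketch of typical usage against the PL 1.3-era Trainer API follows; the `gpus=2` and `accelerator="ddp"` settings are illustrative assumptions, not part of this commit:

from pytorch_lightning import Trainer

# Passing the registered alias is equivalent to constructing the plugin
# explicitly, i.e. plugins=[DDPPlugin(find_unused_parameters=False)].
# Disabling find_unused_parameters skips the graph traversal DDP
# otherwise runs after every backward pass, a common speed-up when the
# model uses all of its parameters in every step.
trainer = Trainer(
    gpus=2,             # illustrative: any multi-GPU setup applies
    accelerator="ddp",  # PL 1.3-era flag selecting distributed data parallel
    plugins="ddp_find_unused_parameters_false",
)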