From 3e8319d4222794849d84359176877d3431817c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 17 Dec 2022 20:00:37 +0100 Subject: [PATCH] Remove redundant `find_unused_parameters=False` in Lite (#16026) --- src/lightning_lite/connector.py | 2 +- src/lightning_lite/strategies/ddp.py | 17 ----------------- src/lightning_lite/strategies/fairscale.py | 13 ------------- tests/tests_lite/strategies/test_ddp.py | 17 +++++++++++++++++ .../strategies/test_fairscale_integration.py | 15 --------------- tests/tests_lite/strategies/test_registry.py | 6 ------ tests/tests_lite/test_connector.py | 10 ++-------- 7 files changed, 20 insertions(+), 60 deletions(-) diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py index dc13671321..59251fc1cd 100644 --- a/src/lightning_lite/connector.py +++ b/src/lightning_lite/connector.py @@ -402,7 +402,7 @@ class _Connector: # TODO this logic should apply to both str and object config strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag - if strategy_flag in ("ddp_spawn", "ddp_spawn_find_unused_parameters_false") and ( + if strategy_flag == "ddp_spawn" and ( TorchElasticEnvironment.detect() or KubeflowEnvironment.detect() or SLURMEnvironment.detect() diff --git a/src/lightning_lite/strategies/ddp.py b/src/lightning_lite/strategies/ddp.py index 38f7a0110a..9cf6963a1f 100644 --- a/src/lightning_lite/strategies/ddp.py +++ b/src/lightning_lite/strategies/ddp.py @@ -43,9 +43,7 @@ from lightning_lite.utilities.rank_zero import rank_zero_only _DDP_FORK_ALIASES = ( "ddp_fork", - "ddp_fork_find_unused_parameters_false", "ddp_notebook", - "ddp_notebook_find_unused_parameters_false", ) @@ -177,21 +175,6 @@ class DDPStrategy(ParallelStrategy): start_method=start_method, ) - entries = ( - ("ddp_find_unused_parameters_false", "popen"), - ("ddp_spawn_find_unused_parameters_false", "spawn"), - ("ddp_fork_find_unused_parameters_false", "fork"), - ("ddp_notebook_find_unused_parameters_false", "fork"), - ) - for name, start_method in entries: - strategy_registry.register( - name, - cls, - description=f"DDP strategy with `find_unused_parameters` as False and `start_method={start_method!r}`", - find_unused_parameters=False, - start_method=start_method, - ) - def _setup_distributed(self) -> None: self._set_world_ranks() rank_zero_only.rank = self.global_rank diff --git a/src/lightning_lite/strategies/fairscale.py b/src/lightning_lite/strategies/fairscale.py index dd2a804d74..b918d99a3e 100644 --- a/src/lightning_lite/strategies/fairscale.py +++ b/src/lightning_lite/strategies/fairscale.py @@ -103,24 +103,11 @@ class DDPShardedStrategy(DDPStrategy): @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: - strategy_registry.register( - "ddp_sharded_find_unused_parameters_false", - cls, - description="DDP Sharded Strategy with `find_unused_parameters` as False", - find_unused_parameters=False, - ) strategy_registry.register( "ddp_sharded", cls, description=cls.__class__.__name__, ) - strategy_registry.register( - "ddp_sharded_spawn_find_unused_parameters_false", - cls, - description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False", - find_unused_parameters=False, - start_method="spawn", - ) strategy_registry.register("ddp_sharded_spawn", cls, description=cls.__class__.__name__, start_method="spawn") diff --git a/tests/tests_lite/strategies/test_ddp.py b/tests/tests_lite/strategies/test_ddp.py index 80bbc6dd61..2182906392 100644 --- a/tests/tests_lite/strategies/test_ddp.py +++ b/tests/tests_lite/strategies/test_ddp.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock, Mock import pytest @@ -62,3 +63,19 @@ def test_ddp_no_backward_sync(): pass module.no_sync.assert_called_once() + + +@mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel") +def test_ddp_extra_kwargs(ddp_mock): + """Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel + wrapper.""" + module = torch.nn.Linear(1, 1) + strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")]) + strategy.setup_module(module) + ddp_mock.assert_called_with(module=module, device_ids=None) + + ddp_mock.reset_mock() + + strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True) + strategy.setup_module(module) + ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True) diff --git a/tests/tests_lite/strategies/test_fairscale_integration.py b/tests/tests_lite/strategies/test_fairscale_integration.py index e2e05b9424..3bcafdfb93 100644 --- a/tests/tests_lite/strategies/test_fairscale_integration.py +++ b/tests/tests_lite/strategies/test_fairscale_integration.py @@ -69,18 +69,3 @@ def test_fairscale_multi_process_checkpoint_state_consolidation(with_fairscale_o weights is identical to the saved one.""" lite = ShardedSaveAndLoad(strategy=strategy, accelerator=accelerator, devices=2) lite.run(tmpdir, with_fairscale_oss=with_fairscale_oss) - - -@pytest.mark.parametrize( - "strategy, expected_find_unused_parameters", - [ - ("ddp_sharded", None), - ("ddp_sharded_find_unused_parameters_false", False), - ("ddp_sharded_spawn", None), - ("ddp_sharded_spawn_find_unused_parameters_false", False), - ], -) -def test_fairscale_find_unused_parameters_from_registry(strategy, expected_find_unused_parameters): - lite = BoringLite(strategy=strategy) - if expected_find_unused_parameters is not None: - assert lite._strategy._ddp_kwargs["find_unused_parameters"] is False diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index 81a49eec08..d6871804a1 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -43,9 +43,7 @@ def test_strategy_registry_with_new_strategy(): def test_available_strategies_in_registry(): expected = { - "ddp_sharded_find_unused_parameters_false", "ddp_sharded", - "ddp_find_unused_parameters_false", "ddp", "deepspeed", "deepspeed_stage_1", @@ -54,14 +52,10 @@ def test_available_strategies_in_registry(): "deepspeed_stage_3", "deepspeed_stage_3_offload", "deepspeed_stage_3_offload_nvme", - "ddp_sharded_spawn_find_unused_parameters_false", "ddp_sharded_spawn", "ddp_spawn", "ddp_fork", "ddp_notebook", - "ddp_spawn_find_unused_parameters_false", - "ddp_fork_find_unused_parameters_false", - "ddp_notebook_find_unused_parameters_false", "single_tpu", "tpu_spawn", "xla", diff --git a/tests/tests_lite/test_connector.py b/tests/tests_lite/test_connector.py index 4230763463..f447e720a6 100644 --- a/tests/tests_lite/test_connector.py +++ b/tests/tests_lite/test_connector.py @@ -81,7 +81,7 @@ def test_strategy_choice_ddp_on_cpu(): def _test_strategy_choice_ddp_and_cpu(ddp_strategy_class): connector = _Connector( - strategy=ddp_strategy_class(find_unused_parameters=True), + strategy=ddp_strategy_class(), accelerator="cpu", devices=2, ) @@ -379,9 +379,7 @@ def test_invalid_strategy_choice(): ["strategy", "strategy_class"], [ ("ddp_spawn", DDPStrategy), - ("ddp_spawn_find_unused_parameters_false", DDPStrategy), ("ddp", DDPStrategy), - ("ddp_find_unused_parameters_false", DDPStrategy), ], ) def test_strategy_choice_cpu_str(strategy, strategy_class): @@ -394,9 +392,7 @@ def test_strategy_choice_cpu_str(strategy, strategy_class): ["strategy", "strategy_class"], [ ("ddp_spawn", DDPStrategy), - ("ddp_spawn_find_unused_parameters_false", DDPStrategy), ("ddp", DDPStrategy), - ("ddp_find_unused_parameters_false", DDPStrategy), ("dp", DataParallelStrategy), ("ddp_sharded", DDPShardedStrategy), ("ddp_sharded_spawn", DDPShardedStrategy), @@ -780,9 +776,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin assert isinstance(connector.precision, plugin_cls) -@pytest.mark.parametrize( - ["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("DDP_FIND_UNUSED_PARAMETERS_FALSE", DDPStrategy)] -) +@pytest.mark.parametrize(["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("Ddp", DDPStrategy)]) def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls): connector = _Connector(strategy=strategy) assert isinstance(connector.strategy, strategy_cls)