Remove redundant `find_unused_parameters=False` in Lite (#16026)

This commit is contained in:
Adrian Wälchli 2022-12-17 20:00:37 +01:00 committed by GitHub
parent acd48d3b0d
commit 3e8319d422
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 20 additions and 60 deletions

View File

@ -402,7 +402,7 @@ class _Connector:
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
if strategy_flag in ("ddp_spawn", "ddp_spawn_find_unused_parameters_false") and (
if strategy_flag == "ddp_spawn" and (
TorchElasticEnvironment.detect()
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()

View File

@ -43,9 +43,7 @@ from lightning_lite.utilities.rank_zero import rank_zero_only
_DDP_FORK_ALIASES = (
"ddp_fork",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook",
"ddp_notebook_find_unused_parameters_false",
)
@ -177,21 +175,6 @@ class DDPStrategy(ParallelStrategy):
start_method=start_method,
)
entries = (
("ddp_find_unused_parameters_false", "popen"),
("ddp_spawn_find_unused_parameters_false", "spawn"),
("ddp_fork_find_unused_parameters_false", "fork"),
("ddp_notebook_find_unused_parameters_false", "fork"),
)
for name, start_method in entries:
strategy_registry.register(
name,
cls,
description=f"DDP strategy with `find_unused_parameters` as False and `start_method={start_method!r}`",
find_unused_parameters=False,
start_method=start_method,
)
def _setup_distributed(self) -> None:
self._set_world_ranks()
rank_zero_only.rank = self.global_rank

View File

@ -103,24 +103,11 @@ class DDPShardedStrategy(DDPStrategy):
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"ddp_sharded_find_unused_parameters_false",
cls,
description="DDP Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
strategy_registry.register(
"ddp_sharded",
cls,
description=cls.__class__.__name__,
)
strategy_registry.register(
"ddp_sharded_spawn_find_unused_parameters_false",
cls,
description="DDP Spawn Sharded Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
start_method="spawn",
)
strategy_registry.register("ddp_sharded_spawn", cls, description=cls.__class__.__name__, start_method="spawn")

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import MagicMock, Mock
import pytest
@ -62,3 +63,19 @@ def test_ddp_no_backward_sync():
pass
module.no_sync.assert_called_once()
@mock.patch("lightning_lite.strategies.ddp.DistributedDataParallel")
def test_ddp_extra_kwargs(ddp_mock):
"""Test that additional kwargs passed to the DDPStrategy get passed down to the DistributedDataParallel
wrapper."""
module = torch.nn.Linear(1, 1)
strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")])
strategy.setup_module(module)
ddp_mock.assert_called_with(module=module, device_ids=None)
ddp_mock.reset_mock()
strategy = DDPStrategy(parallel_devices=[torch.device("cpu"), torch.device("cpu")], find_unused_parameters=True)
strategy.setup_module(module)
ddp_mock.assert_called_with(module=module, device_ids=None, find_unused_parameters=True)

View File

@ -69,18 +69,3 @@ def test_fairscale_multi_process_checkpoint_state_consolidation(with_fairscale_o
weights is identical to the saved one."""
lite = ShardedSaveAndLoad(strategy=strategy, accelerator=accelerator, devices=2)
lite.run(tmpdir, with_fairscale_oss=with_fairscale_oss)
@pytest.mark.parametrize(
"strategy, expected_find_unused_parameters",
[
("ddp_sharded", None),
("ddp_sharded_find_unused_parameters_false", False),
("ddp_sharded_spawn", None),
("ddp_sharded_spawn_find_unused_parameters_false", False),
],
)
def test_fairscale_find_unused_parameters_from_registry(strategy, expected_find_unused_parameters):
lite = BoringLite(strategy=strategy)
if expected_find_unused_parameters is not None:
assert lite._strategy._ddp_kwargs["find_unused_parameters"] is False

View File

@ -43,9 +43,7 @@ def test_strategy_registry_with_new_strategy():
def test_available_strategies_in_registry():
expected = {
"ddp_sharded_find_unused_parameters_false",
"ddp_sharded",
"ddp_find_unused_parameters_false",
"ddp",
"deepspeed",
"deepspeed_stage_1",
@ -54,14 +52,10 @@ def test_available_strategies_in_registry():
"deepspeed_stage_3",
"deepspeed_stage_3_offload",
"deepspeed_stage_3_offload_nvme",
"ddp_sharded_spawn_find_unused_parameters_false",
"ddp_sharded_spawn",
"ddp_spawn",
"ddp_fork",
"ddp_notebook",
"ddp_spawn_find_unused_parameters_false",
"ddp_fork_find_unused_parameters_false",
"ddp_notebook_find_unused_parameters_false",
"single_tpu",
"tpu_spawn",
"xla",

View File

@ -81,7 +81,7 @@ def test_strategy_choice_ddp_on_cpu():
def _test_strategy_choice_ddp_and_cpu(ddp_strategy_class):
connector = _Connector(
strategy=ddp_strategy_class(find_unused_parameters=True),
strategy=ddp_strategy_class(),
accelerator="cpu",
devices=2,
)
@ -379,9 +379,7 @@ def test_invalid_strategy_choice():
["strategy", "strategy_class"],
[
("ddp_spawn", DDPStrategy),
("ddp_spawn_find_unused_parameters_false", DDPStrategy),
("ddp", DDPStrategy),
("ddp_find_unused_parameters_false", DDPStrategy),
],
)
def test_strategy_choice_cpu_str(strategy, strategy_class):
@ -394,9 +392,7 @@ def test_strategy_choice_cpu_str(strategy, strategy_class):
["strategy", "strategy_class"],
[
("ddp_spawn", DDPStrategy),
("ddp_spawn_find_unused_parameters_false", DDPStrategy),
("ddp", DDPStrategy),
("ddp_find_unused_parameters_false", DDPStrategy),
("dp", DataParallelStrategy),
("ddp_sharded", DDPShardedStrategy),
("ddp_sharded_spawn", DDPShardedStrategy),
@ -780,9 +776,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
assert isinstance(connector.precision, plugin_cls)
@pytest.mark.parametrize(
["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("DDP_FIND_UNUSED_PARAMETERS_FALSE", DDPStrategy)]
)
@pytest.mark.parametrize(["strategy", "strategy_cls"], [("DDP", DDPStrategy), ("Ddp", DDPStrategy)])
def test_strategy_str_passed_being_case_insensitive(strategy, strategy_cls):
connector = _Connector(strategy=strategy)
assert isinstance(connector.strategy, strategy_cls)