diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py index cdbcb6bad1..eb02cb01b9 100644 --- a/src/lightning_lite/strategies/__init__.py +++ b/src/lightning_lite/strategies/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from lightning_lite.strategies.dp import DataParallelStrategy # noqa: F401 from lightning_lite.strategies.parallel import ParallelStrategy # noqa: F401 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry from lightning_lite.strategies.single_device import SingleDeviceStrategy # noqa: F401 diff --git a/src/lightning_lite/strategies/dp.py b/src/lightning_lite/strategies/dp.py new file mode 100644 index 0000000000..8ecc239356 --- /dev/null +++ b/src/lightning_lite/strategies/dp.py @@ -0,0 +1,84 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Union + +import torch +from torch import Tensor +from torch.nn import DataParallel, Module + +from lightning_lite.accelerators import Accelerator +from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.plugins.precision import Precision +from lightning_lite.strategies.parallel import ParallelStrategy +from lightning_lite.strategies.strategy import TBroadcast, TReduce +from lightning_lite.utilities.apply_func import apply_to_collection +from lightning_lite.utilities.distributed import ReduceOp + + +class DataParallelStrategy(ParallelStrategy): + """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and + each gets a split of the data.""" + + def __init__( + self, + accelerator: Optional[Accelerator] = None, + parallel_devices: Optional[List[torch.device]] = None, + checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[Precision] = None, + ): + super().__init__( + accelerator=accelerator, + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) + + @property + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[0] + + def setup_module(self, module: Module) -> DataParallel: + """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module.""" + return DataParallel(module=module, device_ids=self.parallel_devices) + + def module_to_device(self, module: Module) -> None: + module.to(self.root_device) + + def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any: + # DataParallel handles the transfer of batch to the device + return batch + + def reduce( + self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> TReduce: + def mean(t: Tensor) -> Tensor: + original_dtype = t.dtype + return t.float().mean().to(original_dtype) + + return apply_to_collection(collection, Tensor, mean) + + def barrier(self, *args: Any, **kwargs: Any) -> None: + pass + + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + return obj + + def reduce_boolean_decision(self, decision: bool) -> bool: + return decision + + @classmethod + def register_strategies(cls, strategy_registry: Dict) -> None: + strategy_registry.register("dp", cls, description=cls.__class__.__name__) diff --git a/tests/tests_lite/strategies/test_dp.py b/tests/tests_lite/strategies/test_dp.py new file mode 100644 index 0000000000..12a98d8e46 --- /dev/null +++ b/tests/tests_lite/strategies/test_dp.py @@ -0,0 +1,50 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest import mock +from unittest.mock import Mock + +import torch + +from lightning_lite.strategies import DataParallelStrategy + + +def test_data_parallel_root_device(): + strategy = DataParallelStrategy() + strategy.parallel_devices = [torch.device("cuda", 2), torch.device("cuda", 0), torch.device("cuda", 1)] + assert strategy.root_device == torch.device("cuda", 2) + + +def test_data_parallel_ranks(): + strategy = DataParallelStrategy() + assert strategy.world_size == 1 + assert strategy.local_rank == 0 + assert strategy.global_rank == 0 + assert strategy.is_global_zero + + +@mock.patch("lightning_lite.strategies.dp.DataParallel") +def test_data_parallel_setup_module(data_parallel_mock): + strategy = DataParallelStrategy() + strategy.parallel_devices = [0, 2, 1] + module = torch.nn.Linear(2, 2) + wrapped_module = strategy.setup_module(module) + assert wrapped_module == data_parallel_mock(module=module, device_ids=[0, 2, 1]) + + +def test_data_parallel_module_to_device(): + strategy = DataParallelStrategy() + strategy.parallel_devices = [torch.device("cuda", 2)] + module = Mock() + strategy.module_to_device(module) + module.to.assert_called_with(torch.device("cuda", 2)) diff --git a/tests/tests_lite/strategies/test_registry.py b/tests/tests_lite/strategies/test_registry.py index d94198d571..76a6bea00f 100644 --- a/tests/tests_lite/strategies/test_registry.py +++ b/tests/tests_lite/strategies/test_registry.py @@ -41,6 +41,7 @@ def test_strategy_registry_with_new_strategy(): def test_available_strategies_in_registry(): - assert STRATEGY_REGISTRY.available_strategies() == [ + assert set(STRATEGY_REGISTRY.available_strategies()) == { + "dp", "single_tpu", - ] + }