diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 3740e3a8f5..6e75f250f5 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914))
 
 
+- Removed the need to call `.launch()` when using the DP-strategy (`strategy="dp"`) ([#17931](https://github.com/Lightning-AI/lightning/pull/17931))
+
+
 - Fixed automatic step tracking in Fabric's CSVLogger ([#17942](https://github.com/Lightning-AI/lightning/pull/17942))
 
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index c6dcd102d3..0e39bb7833 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -33,7 +33,14 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.plugins import Precision  # avoid circular imports: # isort: split
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
+from lightning.fabric.strategies import (
+    DataParallelStrategy,
+    DeepSpeedStrategy,
+    FSDPStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
 from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning.fabric.utilities import move_data_to_device
@@ -919,7 +926,7 @@ class Fabric:
         setattr(self, "run", partial(self._wrap_and_launch, self.run))
 
     def _validate_launched(self) -> None:
-        if not self._launched and not isinstance(self._strategy, SingleDeviceStrategy):
+        if not self._launched and not isinstance(self._strategy, (SingleDeviceStrategy, DataParallelStrategy)):
             raise RuntimeError(
                 "To use Fabric with more than one device, you must call `.launch()` or use the CLI:"
                 " `lightning run model --help`."
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 9a9991eacc..befe22620a 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Samp
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
+    DataParallelStrategy,
     DDPStrategy,
     DeepSpeedStrategy,
     ParallelStrategy,
@@ -1134,6 +1135,8 @@ def test_verify_launch_called():
     assert not fabric._launched
     fabric._strategy = Mock(spec=SingleDeviceStrategy)
     fabric._validate_launched()
+    fabric._strategy = Mock(spec=DataParallelStrategy)
+    fabric._validate_launched()
     fabric._strategy = Mock(spec=DDPStrategy)
     with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"):
         fabric._validate_launched()
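
For context, a minimal sketch (not part of the patch) of the behavior this enables: since DataParallel runs in a single process, `_validate_launched` now treats it like the single-device case, so `fabric.setup(...)` no longer raises when `.launch()` was never called. The toy model and optimizer below are illustrative only, and the snippet assumes a machine where `strategy="dp"` applies (e.g., two CUDA devices).

```python
import torch
from lightning.fabric import Fabric

# DP keeps everything in one process, so there is no worker process to launch.
fabric = Fabric(accelerator="cuda", devices=2, strategy="dp")

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Before this patch, calling `setup` here without a prior `fabric.launch()`
# raised: RuntimeError: To use Fabric with more than one device, you must
# call `.launch()` or use the CLI ...
model, optimizer = fabric.setup(model, optimizer)
```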