Remove requirement to call `Fabric.launch()` with DP strategy (#17931)

This commit is contained in:
Adrian Wälchli 2023-06-30 01:20:01 -07:00 committed by GitHub
parent 28beb8a478
commit 5d7669af46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 2 deletions

View File

@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914))
- Removed the need to call `.launch()` when using the DP-strategy (`strategy="dp"`) ([#17931](https://github.com/Lightning-AI/lightning/pull/17931))
- Fixed automatic step tracking in Fabric's CSVLogger ([#17942](https://github.com/Lightning-AI/lightning/pull/17942))

View File

@ -33,7 +33,14 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
from lightning.fabric.strategies import (
DataParallelStrategy,
DeepSpeedStrategy,
FSDPStrategy,
SingleDeviceStrategy,
Strategy,
XLAStrategy,
)
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
from lightning.fabric.utilities import move_data_to_device
@ -919,7 +926,7 @@ class Fabric:
setattr(self, "run", partial(self._wrap_and_launch, self.run))
def _validate_launched(self) -> None:
if not self._launched and not isinstance(self._strategy, SingleDeviceStrategy):
if not self._launched and not isinstance(self._strategy, (SingleDeviceStrategy, DataParallelStrategy)):
raise RuntimeError(
"To use Fabric with more than one device, you must call `.launch()` or use the CLI:"
" `lightning run model --help`."

View File

@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Samp
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.strategies import (
DataParallelStrategy,
DDPStrategy,
DeepSpeedStrategy,
ParallelStrategy,
@ -1134,6 +1135,8 @@ def test_verify_launch_called():
assert not fabric._launched
fabric._strategy = Mock(spec=SingleDeviceStrategy)
fabric._validate_launched()
fabric._strategy = Mock(spec=DataParallelStrategy)
fabric._validate_launched()
fabric._strategy = Mock(spec=DDPStrategy)
with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"):
fabric._validate_launched()