Remove requirement to call `Fabric.launch()` with DP strategy (#17931)
Commit 5d7669af46 (parent 28beb8a478)
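In practice, a single-process DP script now runs as-is. A minimal sketch of the newly allowed usage (the model, optimizer, and data are placeholders; assumes a machine with at least two CUDA devices):

```python
import torch
from lightning.fabric import Fabric

# DP keeps all replicas in one process, so no .launch() is needed anymore.
fabric = Fabric(accelerator="cuda", devices=2, strategy="dp")

model = torch.nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)  # raised RuntimeError before this change

batch = torch.randn(8, 32, device=fabric.device)
loss = model(batch).sum()
fabric.backward(loss)
optimizer.step()
```

Strategies that actually spawn processes, such as `strategy="ddp"`, still require `.launch()` or the `lightning run model` CLI, as the unchanged error message in the diff below states.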
src/lightning/fabric/CHANGELOG.md

@@ -142,6 +142,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed check for FSDP's flat parameters in all parameter groups ([#17914](https://github.com/Lightning-AI/lightning/pull/17914))

+- Removed the need to call `.launch()` when using the DP-strategy (`strategy="dp"`) ([#17931](https://github.com/Lightning-AI/lightning/pull/17931))
+
+
 - Fixed automatic step tracking in Fabric's CSVLogger ([#17942](https://github.com/Lightning-AI/lightning/pull/17942))
src/lightning/fabric/fabric.py

@@ -33,7 +33,14 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0

 from lightning.fabric.plugins import Precision  # avoid circular imports: # isort: split
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.connector import _Connector, _is_using_cli, _PLUGIN_INPUT, _PRECISION_INPUT
-from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
+from lightning.fabric.strategies import (
+    DataParallelStrategy,
+    DeepSpeedStrategy,
+    FSDPStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    XLAStrategy,
+)
 from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
 from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning.fabric.utilities import move_data_to_device

@@ -919,7 +926,7 @@ class Fabric:
         setattr(self, "run", partial(self._wrap_and_launch, self.run))

     def _validate_launched(self) -> None:
-        if not self._launched and not isinstance(self._strategy, SingleDeviceStrategy):
+        if not self._launched and not isinstance(self._strategy, (SingleDeviceStrategy, DataParallelStrategy)):
             raise RuntimeError(
                 "To use Fabric with more than one device, you must call `.launch()` or use the CLI:"
                 " `lightning run model --help`."
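The new exemption mirrors the existing one for `SingleDeviceStrategy`: `strategy="dp"` wraps the module in `torch.nn.DataParallel`, which scatters each batch and replicates the model across GPUs inside the one process that is already running, so there is nothing for a launcher to spawn. A rough sketch of that mechanism in plain PyTorch (not Fabric internals; assumes two CUDA devices):

```python
import torch

model = torch.nn.Linear(32, 2).to("cuda:0")

# One process drives all devices: the batch is scattered across replicas,
# each replica runs the forward pass, and outputs are gathered on device 0.
dp_model = torch.nn.DataParallel(model, device_ids=[0, 1])
output = dp_model(torch.randn(8, 32, device="cuda:0"))
```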
tests/tests_fabric/test_fabric.py

@@ -27,6 +27,7 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
+    DataParallelStrategy,
     DDPStrategy,
     DeepSpeedStrategy,
     ParallelStrategy,

@@ -1134,6 +1135,8 @@ def test_verify_launch_called():
     assert not fabric._launched
     fabric._strategy = Mock(spec=SingleDeviceStrategy)
     fabric._validate_launched()
+    fabric._strategy = Mock(spec=DataParallelStrategy)
+    fabric._validate_launched()
     fabric._strategy = Mock(spec=DDPStrategy)
     with pytest.raises(RuntimeError, match=r"you must call `.launch\(\)`"):
         fabric._validate_launched()
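The updated test leans on the fact that `Mock(spec=...)` passes `isinstance` checks for the spec'd class, which lets it drive `_validate_launched()` without constructing real strategies or touching any devices. For instance:

```python
from unittest.mock import Mock

from lightning.fabric.strategies import DataParallelStrategy

# A spec'd mock reports the spec class as its __class__, so the
# isinstance check inside _validate_launched() accepts it.
mock_strategy = Mock(spec=DataParallelStrategy)
assert isinstance(mock_strategy, DataParallelStrategy)
```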