Make DDP subprocess the default launcher for multi-device (#16780)

Adrian Wälchli 2023-02-20 12:20:50 +01:00 committed by GitHub
parent 3a0519143a
commit 81b7c30291
20 changed files with 67 additions and 178 deletions

View File

@ -20,19 +20,15 @@ Let's say you have a batch size of 7 in your dataloader.
def train_dataloader(self):
return Dataset(..., batch_size=7)
In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED your effective batch size will be 7 * devices * num_nodes.
Whenever you use multiple devices and/or nodes, your effective batch size will be 7 * devices * num_nodes.
.. code-block:: python
# effective batch size = 7 * 8
Trainer(accelerator="gpu", devices=8, strategy="ddp")
Trainer(accelerator="gpu", devices=8, strategy="ddp_spawn")
Trainer(accelerator="gpu", devices=8, strategy="ddp_sharded")
Trainer(accelerator="gpu", devices=8, strategy=...)
# effective batch size = 7 * 8 * 10
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_spawn")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy="ddp_sharded")
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...)
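As a quick sanity check of the multiplication above, here is a minimal sketch using the same illustrative numbers:
.. code-block:: python
per_process_batch_size = 7
devices, num_nodes = 8, 10
# each process sees 7 samples, so one optimizer step covers 7 * 8 * 10 samples
effective_batch_size = per_process_batch_size * devices * num_nodes
print(effective_batch_size)  # 560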
.. note:: Huge batch sizes are actually really bad for convergence. Check out:
@ -45,13 +41,13 @@ In DDP, DDP_SPAWN, Deepspeed, DDP_SHARDED your effective batch size will be 7 *
How do I use multiple GPUs on Jupyter or Colab notebooks?
*********************************************************
To use multiple GPUs on notebooks, use the *DDP_SPAWN* or *DDP_NOTEBOOK* mode.
To use multiple GPUs on notebooks, use the *DDP_NOTEBOOK* mode.
.. code-block:: python
Trainer(accelerator="gpu", devices=4, strategy="ddp_notebook" | "ddp_spawn")
Trainer(accelerator="gpu", devices=4, strategy="ddp_notebook")
If you want to use other models, please launch your training via the command-shell.
If you want to use other strategies, please launch your training via the command-shell.
----
@ -59,7 +55,7 @@ If you want to use other models, please launch your training via the command-she
I'm getting errors related to Pickling. What do I do?
*****************************************************
Pickle is Python's mechanism for serializing and unserializing data. A majority of distributed modes require that your code is fully pickle compliant. If you run into an issue with pickling try the following to figure out the issue
Pickle is Python's mechanism for serializing and deserializing data. Some distributed modes require that your code is fully pickle-compliant. If you run into an issue with pickling, try the following to figure out the issue.
.. code-block:: python
@ -68,14 +64,14 @@ Pickle is Python's mechanism for serializing and unserializing data. A majority
model = YourModel()
pickle.dumps(model)
If you `ddp` your code doesn't need to be pickled.
.. code-block:: python
Trainer(accelerator="gpu", devices=4, strategy="ddp")
If you use `ddp_spawn` the pickling requirement remains. This is a limitation of Python.
For example, the `ddp_spawn` strategy has the pickling requirement. This is a limitation of Python.
.. code-block:: python
Trainer(accelerator="gpu", devices=4, strategy="ddp_spawn")
If you use `ddp`, your code doesn't need to be pickled:
.. code-block:: python
Trainer(accelerator="gpu", devices=4, strategy="ddp")
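As an illustration, a minimal sketch of a common source of pickling errors: an object that holds a lambda (the class name here is hypothetical, not from the docs).
.. code-block:: python
import pickle


class ModelWithLambda:  # hypothetical example
    def __init__(self):
        self.activation = lambda x: x * 2  # lambdas cannot be pickled


try:
    pickle.dumps(ModelWithLambda())
except Exception as err:  # typically a PicklingError or AttributeError
    print(f"not pickle-compliant: {err}")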

View File

@ -26,7 +26,7 @@ Lightning supports multiple ways of doing distributed training.
- Notebook/Fork (``strategy='ddp_notebook'``)
.. note::
If you request multiple GPUs or nodes without setting a mode, DDP Spawn will be automatically used.
If you request multiple GPUs or nodes without setting a strategy, DDP will be automatically used.
For a deeper understanding of what Lightning is doing, feel free to read this
`guide <https://medium.com/@_willfalcon/9-tips-for-training-lightning-fast-neural-networks-in-pytorch-8e63a502f565>`_.
@ -196,13 +196,9 @@ Comparison of DDP variants and tradeoffs
- No
- Yes
- No
* - Is the guard ``if __name__=="__main__"`` required?
- Yes
- Yes
- No
* - Limitations in the main process
- None
- None
- The state of objects is not up-to-date after returning to the main process (`Trainer.fit()` etc). Only the model parameters get transferred over.
- GPU operations such as moving tensors to the GPU or calling ``torch.cuda`` functions before invoking ``Trainer.fit`` are not allowed.
* - Process creation time
- Slow
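The main-process limitation above can be illustrated with a small sketch (the ``BoringModel`` import path and the attribute name are assumptions for illustration):
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class TaggingModel(BoringModel):
    def on_train_start(self):
        self.ran_in_worker = True  # set only inside the spawned worker process


model = TaggingModel()
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", fast_dev_run=True)
trainer.fit(model)
# Only the model parameters are transferred back to the main process,
# so the attribute set in the worker is not visible here:
print(hasattr(model, "ran_in_worker"))  # False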

View File

@ -1055,32 +1055,23 @@ By setting to False, you have to add your own distributed sampler:
strategy
^^^^^^^^
Supports passing different training strategies with aliases (ddp, ddp_spawn, etc) as well as custom strategies.
Supports passing different training strategies with aliases (ddp, fsdp, etc) as well as configured strategies.
.. code-block:: python
# Training with the DistributedDataParallel strategy on 4 GPUs
# Data-parallel training with the DDP strategy on 4 GPUs
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)
# Training with the DDP Spawn strategy using 4 cpu processes
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=4)
# Model-parallel training with the FSDP strategy on 4 GPUs
trainer = Trainer(strategy="fsdp", accelerator="gpu", devices=4)
.. note:: Additionally, you can pass your custom strategy to the ``strategy`` argument.
Additionally, you can pass a strategy object.
.. code-block:: python
from pytorch_lightning.strategies import DDPStrategy
class CustomDDPStrategy(DDPStrategy):
def configure_ddp(self):
self._model = MyCustomDistributedDataParallel(
self.model,
device_ids=...,
)
trainer = Trainer(strategy=CustomDDPStrategy(), accelerator="gpu", devices=2)
trainer = Trainer(strategy=DDPStrategy(static_graph=True), accelerator="gpu", devices=2)
See Also:
- :ref:`Multi GPU Training <multi_gpu>`.
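A strategy object can also carry options that have no string alias, for example a custom process-group timeout (a sketch; the ``timeout`` argument is forwarded to ``torch.distributed.init_process_group``):
.. code-block:: python
from datetime import timedelta

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# configure a non-default timeout for collective operations
trainer = Trainer(strategy=DDPStrategy(timeout=timedelta(minutes=5)), accelerator="gpu", devices=2)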

View File

@ -74,7 +74,7 @@ The below table lists all relevant strategies available in Lightning with their
- Strategy for Fully Sharded Data Parallel training. :ref:`Learn more. <advanced/model_parallel:Fully Sharded Training>`
* - ddp_spawn
- :class:`~pytorch_lightning.strategies.DDPSpawnStrategy`
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
- Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training finishes. Useful for debugging. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel Spawn>`
* - ddp
- :class:`~pytorch_lightning.strategies.DDPStrategy`
- Strategy for multi-process single-device training on one or multiple nodes. :ref:`Learn more. <accelerators/gpu_intermediate:Distributed Data Parallel>`
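A short sketch of how these strategies are selected after this change: subprocess-launched DDP is the multi-device default, while spawn-based launching must be requested explicitly, e.g. for debugging.
.. code-block:: python
from pytorch_lightning import Trainer

# "ddp" is now selected automatically for multi-device runs;
# "ddp_spawn" has to be named explicitly.
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn")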

View File

@ -202,7 +202,6 @@ DataLoaders
Lightning uses :class:`~torch.utils.data.DataLoader` to handle all the data flow through the system. Whenever you structure dataloaders,
make sure to tune the number of workers for maximum efficiency.
.. warning:: Make sure not to use ``Trainer(strategy="ddp_spawn")`` with ``num_workers>0`` in the DataLoader or you will bottleneck you code.
DataModules
===========

View File

@ -40,6 +40,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16767](https://github.com/Lightning-AI/lightning/pull/16767))
- The selection `Fabric(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))
### Deprecated
-

View File

@ -415,14 +415,6 @@ class _Connector:
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
if strategy_flag == "ddp_spawn" and (
TorchElasticEnvironment.detect()
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()
or LSFEnvironment.detect()
or MPIEnvironment.detect()
):
strategy_flag = "ddp"
if strategy_flag == "dp" and self._accelerator_flag == "cpu":
rank_zero_warn(f"{strategy_flag!r} is not supported on CPUs, hence setting `strategy='ddp'`.")
strategy_flag = "ddp"
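# Sketch of the behavior after removing the fallback above (per the changelog
# entry for #16780): an explicit "ddp_spawn" selection in Fabric is kept even
# when a cluster environment such as SLURM or TorchElastic is detected, instead
# of being silently replaced with "ddp".
from lightning.fabric import Fabric

fabric = Fabric(strategy="ddp_spawn", accelerator="cpu", devices=2)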

View File

@ -109,6 +109,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) ([#16783](https://github.com/Lightning-AI/lightning/pull/16783))
- When using multiple devices, the strategy now defaults to "ddp" instead of "ddp_spawn" when none is set ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))
- The selection `Trainer(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))
### Deprecated
-

View File

@ -450,12 +450,9 @@ class AcceleratorConnector:
device = "cpu"
# TODO: lazy initialized device, then here could be self._strategy_flag = "single_device"
return SingleDeviceStrategy(device=device) # type: ignore
if len(self._parallel_devices) > 1:
if _IS_INTERACTIVE:
return "ddp_fork"
return "ddp_spawn"
return DDPStrategy.strategy_name
if len(self._parallel_devices) > 1 and _IS_INTERACTIVE:
return "ddp_fork"
return "ddp"
def _check_strategy_and_fallback(self) -> None:
"""Checks edge cases when the strategy selection was a string input, and we need to fall back to a
@ -464,18 +461,6 @@ class AcceleratorConnector:
# TODO this logic should apply to both str and object config
strategy_flag = "" if isinstance(self._strategy_flag, Strategy) else self._strategy_flag
if strategy_flag in (
"ddp_spawn",
"ddp_spawn_find_unused_parameters_false",
"ddp_spawn_find_unused_parameters_true",
) and (
TorchElasticEnvironment.detect()
or KubeflowEnvironment.detect()
or SLURMEnvironment.detect()
or LSFEnvironment.detect()
or MPIEnvironment.detect()
):
strategy_flag = "ddp"
if (
strategy_flag in FSDPStrategy.get_registered_strategies() or isinstance(self._strategy_flag, FSDPStrategy)
) and self._accelerator_flag not in ("cuda", "gpu"):
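# Sketch of the new default selection implemented above: with multiple devices
# and no explicit strategy, the Trainer now picks subprocess-launched DDP;
# interactive environments (notebooks) still get "ddp_fork".
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

trainer = Trainer(accelerator="cpu", devices=2)
assert isinstance(trainer.strategy, DDPStrategy)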

View File

@ -537,9 +537,9 @@ def test_strategy_choice_ddp_spawn(*_):
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)])
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy])
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy])
def test_strategy_choice_ddp_slurm(_, strategy, job_name, expected_env):
if not isinstance(strategy, str):
if strategy and not isinstance(strategy, str):
strategy = strategy()
with mock.patch.dict(
@ -574,8 +574,8 @@ def test_strategy_choice_ddp_slurm(_, strategy, job_name, expected_env):
)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_strategy_choice_ddp_te(*_):
connector = _Connector(strategy="ddp", accelerator="gpu", devices=2)
def test_strategy_choice_ddp_torchelastic(*_):
connector = _Connector(accelerator="gpu", devices=2)
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment)
@ -583,26 +583,6 @@ def test_strategy_choice_ddp_te(*_):
assert connector.strategy.local_rank == 1
@mock.patch.dict(
os.environ,
{
"WORLD_SIZE": "2",
"LOCAL_WORLD_SIZE": "2",
"RANK": "1",
"LOCAL_RANK": "1",
"GROUP_RANK": "0",
"TORCHELASTIC_RUN_ID": "1",
},
)
def test_strategy_choice_ddp_cpu_te():
connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment)
assert connector.strategy.cluster_environment.local_rank() == 1
assert connector.strategy.local_rank == 1
@mock.patch.dict(
os.environ,
{
@ -614,10 +594,10 @@ def test_strategy_choice_ddp_cpu_te():
"RANK": "1",
},
)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=1)
@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
def test_strategy_choice_ddp_kubeflow(*_):
connector = _Connector(strategy="ddp", accelerator="gpu", devices=1)
connector = _Connector(accelerator="gpu", devices=2)
assert isinstance(connector.accelerator, CUDAAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, KubeflowEnvironment)
@ -636,7 +616,7 @@ def test_strategy_choice_ddp_kubeflow(*_):
},
)
def test_strategy_choice_ddp_cpu_kubeflow():
connector = _Connector(strategy="ddp_spawn", accelerator="cpu", devices=2)
connector = _Connector(accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)
assert isinstance(connector.strategy, DDPStrategy)
assert isinstance(connector.strategy.cluster_environment, KubeflowEnvironment)
@ -656,7 +636,7 @@ def test_strategy_choice_ddp_cpu_kubeflow():
"SLURM_LOCALID": "0",
},
)
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()])
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy()])
def test_strategy_choice_ddp_cpu_slurm(strategy):
connector = _Connector(strategy=strategy, accelerator="cpu", devices=2)
assert isinstance(connector.accelerator, CPUAccelerator)

View File

@ -20,7 +20,7 @@ from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.strategies import DDPStrategy
def test_pluggable_accelerator():
def test_pluggable_accelerator(mps_count_0, cuda_count_2):
class TestAccelerator(Accelerator):
def setup_device(self, device: torch.device) -> None:
pass

View File

@ -238,7 +238,7 @@ def test_auto_parameters_tying_tpus_nested_module(tmpdir):
assert torch.all(torch.eq(model.net_a.layer.weight, model.net_b.layer.weight))
def test_tpu_invalid_raises(tpu_available):
def test_tpu_invalid_raises(tpu_available, mps_count_0):
strategy = XLAStrategy(accelerator=TPUAccelerator(), precision_plugin=PrecisionPlugin())
with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
Trainer(strategy=strategy, devices=8)
@ -248,7 +248,7 @@ def test_tpu_invalid_raises(tpu_available):
Trainer(strategy=strategy, devices=8)
def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available):
def test_tpu_invalid_raises_set_precision_with_strategy(tpu_available, mps_count_0):
accelerator = TPUAccelerator()
strategy = XLAStrategy(accelerator=accelerator, precision_plugin=PrecisionPlugin())
with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):

View File

@ -94,7 +94,7 @@ def test_amp_gpus(tmpdir, strategy, precision, devices):
max_epochs=1,
accelerator="gpu",
devices=devices,
strategy=strategy,
strategy=("ddp_spawn" if strategy is None and devices > 1 else strategy),
precision=precision,
)

View File

@ -43,6 +43,7 @@ def test_multi_gpu_none_backend(tmpdir):
limit_train_batches=0.2,
limit_val_batches=0.2,
accelerator="gpu",
strategy="ddp_spawn",
devices=2,
)
@ -62,6 +63,7 @@ def test_single_gpu_model(tmpdir, devices):
limit_val_batches=0.1,
accelerator="gpu",
devices=devices,
strategy="ddp_spawn",
)
model = BoringModel()

View File

@ -75,7 +75,7 @@ def test_torch_distributed_backend_invalid(cuda_count_2, tmpdir):
@RunIf(skip_windows=True)
@mock.patch("torch.cuda.set_device")
@mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision")
def test_ddp_torch_dist_is_available_in_setup(_, __, cuda_count_1, tmpdir):
def test_ddp_torch_dist_is_available_in_setup(_, __, cuda_count_1, mps_count_0, tmpdir):
"""Test to ensure torch distributed is available within the setup hook using ddp."""
class TestModel(BoringModel):

View File

@ -46,14 +46,11 @@ class BoringCallbackDDPSpawnModel(BoringModel):
@RunIf(skip_windows=True)
def test_ddp_cpu():
"""Tests if device is set correctly when training for DDPSpawnStrategy."""
trainer = Trainer(devices=2, accelerator="cpu", fast_dev_run=True)
trainer = Trainer(devices=2, strategy="ddp_spawn", accelerator="cpu", fast_dev_run=True)
# assert strategy attributes for device setting
assert isinstance(trainer.strategy, DDPSpawnStrategy)
assert trainer.strategy.root_device == torch.device("cpu")
model = BoringModelDDPCPU()
trainer.fit(model)
@ -125,7 +122,7 @@ def test_ddp_spawn_configure_ddp(tmpdir):
@mock.patch("torch.distributed.init_process_group")
def test_ddp_spawn_strategy_set_timeout(mock_init_process_group):
def test_ddp_spawn_strategy_set_timeout(mock_init_process_group, cuda_count_2, mps_count_0):
"""Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
test_timedelta = timedelta(seconds=30)
model = BoringModel()

View File

@ -115,7 +115,7 @@ def test_incorrect_ddp_script_spawning(tmpdir):
@RunIf(skip_windows=True)
def test_ddp_configure_ddp():
def test_ddp_configure_ddp(cuda_count_2, mps_count_0):
"""Tests with ddp strategy."""
model = BoringModel()
ddp_strategy = DDPStrategy()
@ -229,7 +229,7 @@ def test_configure_launcher_create_processes_externally():
@mock.patch("torch.distributed.init_process_group")
def test_ddp_strategy_set_timeout(mock_init_process_group):
def test_ddp_strategy_set_timeout(mock_init_process_group, cuda_count_2, mps_count_0):
"""Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
test_timedelta = timedelta(seconds=30)
model = BoringModel()

View File

@ -26,7 +26,6 @@ import lightning.pytorch
from lightning.fabric.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
LSFEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
XLAEnvironment,
@ -222,43 +221,6 @@ def test_custom_accelerator(cuda_count_0):
assert trainer._accelerator_connector.strategy is strategy
@pytest.mark.parametrize(
"env_vars,expected_environment",
[
(
{
"SLURM_NTASKS": "2",
"SLURM_NTASKS_PER_NODE": "1",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_PROCID": "0",
"SLURM_LOCALID": "0",
},
SLURMEnvironment,
),
(
{
"LSB_JOBID": "1",
"LSB_DJOB_RANKFILE": "SOME_RANK_FILE",
"JSM_NAMESPACE_LOCAL_RANK": "1",
"JSM_NAMESPACE_SIZE": "20",
"JSM_NAMESPACE_RANK": "1",
},
LSFEnvironment,
),
],
)
@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._read_hosts", return_value=["node0", "node1"])
@mock.patch("lightning.fabric.plugins.environments.lsf.LSFEnvironment._get_node_rank", return_value=0)
def test_fallback_from_ddp_spawn_to_ddp_on_cluster(_, __, env_vars, expected_environment):
with mock.patch.dict(os.environ, env_vars, clear=True):
trainer = Trainer(strategy="ddp_spawn", accelerator="cpu", devices=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, expected_environment)
@RunIf(mps=False)
def test_interactive_incompatible_backend_error(cuda_count_2, monkeypatch):
monkeypatch.setattr(lightning.pytorch.trainer.connectors.accelerator_connector, "_IS_INTERACTIVE", True)
@ -333,7 +295,7 @@ def test_accelerator_gpu():
assert isinstance(trainer.accelerator, CUDAAccelerator)
@pytest.mark.parametrize(["devices", "strategy_class"], [(1, SingleDeviceStrategy), (5, DDPSpawnStrategy)])
@pytest.mark.parametrize(["devices", "strategy_class"], [(1, SingleDeviceStrategy), (5, DDPStrategy)])
def test_accelerator_cpu_with_devices(devices, strategy_class):
trainer = Trainer(accelerator="cpu", devices=devices)
assert trainer.num_devices == devices
@ -343,7 +305,7 @@ def test_accelerator_cpu_with_devices(devices, strategy_class):
@RunIf(min_cuda_gpus=2)
@pytest.mark.parametrize(
["devices", "strategy_class"], [(1, SingleDeviceStrategy), ([1], SingleDeviceStrategy), (2, DDPSpawnStrategy)]
["devices", "strategy_class"], [(1, SingleDeviceStrategy), ([1], SingleDeviceStrategy), (2, DDPStrategy)]
)
def test_accelerator_gpu_with_devices(devices, strategy_class):
trainer = Trainer(accelerator="gpu", devices=devices)
@ -478,9 +440,9 @@ def test_strategy_choice_ddp_cuda(strategy, expected_cls, mps_count_0, cuda_coun
@pytest.mark.parametrize("job_name,expected_env", [("some_name", SLURMEnvironment), ("bash", LightningEnvironment)])
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy])
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy])
def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_env):
if not isinstance(strategy, str):
if strategy and not isinstance(strategy, str):
strategy = strategy()
with mock.patch.dict(
@ -515,8 +477,8 @@ def test_strategy_choice_ddp_slurm(cuda_count_2, strategy, job_name, expected_en
)
@mock.patch("torch.cuda.set_device")
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_strategy_choice_ddp_te(_, __, mps_count_0, cuda_count_2):
trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=2)
def test_strategy_choice_ddp_torchelastic(_, __, mps_count_0, cuda_count_2):
trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=2)
assert isinstance(trainer.accelerator, CUDAAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment)
@ -524,27 +486,6 @@ def test_strategy_choice_ddp_te(_, __, mps_count_0, cuda_count_2):
assert trainer.strategy.local_rank == 1
@mock.patch.dict(
os.environ,
{
"WORLD_SIZE": "2",
"LOCAL_WORLD_SIZE": "2",
"RANK": "1",
"LOCAL_RANK": "1",
"GROUP_RANK": "0",
"TORCHELASTIC_RUN_ID": "1",
},
)
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_strategy_choice_ddp_cpu_te(cuda_count_0):
trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, TorchElasticEnvironment)
assert trainer.strategy.cluster_environment.local_rank() == 1
assert trainer.strategy.local_rank == 1
@mock.patch.dict(
os.environ,
{
@ -558,8 +499,8 @@ def test_strategy_choice_ddp_cpu_te(cuda_count_0):
)
@mock.patch("torch.cuda.set_device")
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_strategy_choice_ddp_kubeflow(_, __, mps_count_0, cuda_count_1):
trainer = Trainer(fast_dev_run=True, strategy="ddp", accelerator="gpu", devices=1)
def test_strategy_choice_ddp_kubeflow(_, __, mps_count_0, cuda_count_2):
trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=2)
assert isinstance(trainer.accelerator, CUDAAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment)
@ -579,7 +520,7 @@ def test_strategy_choice_ddp_kubeflow(_, __, mps_count_0, cuda_count_1):
)
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True)
def test_strategy_choice_ddp_cpu_kubeflow(cuda_count_0):
trainer = Trainer(fast_dev_run=True, strategy="ddp_spawn", accelerator="cpu", devices=2)
trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=2)
assert isinstance(trainer.accelerator, CPUAccelerator)
assert isinstance(trainer.strategy, DDPStrategy)
assert isinstance(trainer.strategy.cluster_environment, KubeflowEnvironment)
@ -600,7 +541,7 @@ def test_strategy_choice_ddp_cpu_kubeflow(cuda_count_0):
},
)
@mock.patch("lightning.pytorch.strategies.DDPStrategy.setup_distributed", autospec=True)
@pytest.mark.parametrize("strategy", ["ddp", DDPStrategy()])
@pytest.mark.parametrize("strategy", [None, "ddp", DDPStrategy()])
def test_strategy_choice_ddp_cpu_slurm(cuda_count_0, strategy):
trainer = Trainer(fast_dev_run=True, strategy=strategy, accelerator="cpu", devices=2)
assert isinstance(trainer.accelerator, CPUAccelerator)

View File

@ -375,7 +375,7 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_s
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
@pytest.mark.parametrize("mode", ("min_size", "max_size_cycle", "sequential"))
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader):
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader, mps_count_0):
"""When providing a CombinedLoader as the training data, it should be correctly receive the distributed
samplers."""
dim = 3

View File

@ -1893,7 +1893,7 @@ def test_detect_anomaly_nan(tmpdir):
CUDAAccelerator,
1,
),
({"strategy": None, "accelerator": "cuda", "devices": 2}, DDPSpawnStrategy, "ddp_spawn", CUDAAccelerator, 2),
({"strategy": None, "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
({"strategy": "ddp", "accelerator": "cuda", "devices": 2}, DDPStrategy, "ddp", CUDAAccelerator, 2),
({"strategy": "ddp", "accelerator": "cpu", "devices": 2}, DDPStrategy, "ddp", CPUAccelerator, 2),
(