Fix distributed types support for CPUs (#8667)

This commit is contained in:
Kaushik B 2021-08-02 16:42:28 +05:30 committed by GitHub
parent 85bba06529
commit 850416f0a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 24 additions and 5 deletions

View File

@@ -779,12 +779,13 @@ class AcceleratorConnector:
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
# DP and DDP2 cannot run without GPU
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _use_cpu:
rank_zero_warn(
"You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`."
)
# todo: in some cases it yield in comparison None and int
if (self.num_nodes and self.num_nodes > 1) or (self.num_processes and self.num_processes > 1):
self._distrib_type = DistributedType.DDP
if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
rank_zero_warn(
f"{self._distrib_type} is not supported on CPUs, hence setting the distributed type to `ddp`."
)
self._distrib_type = DistributedType.DDP
else:
rank_zero_warn("You are running on single node with no parallelization, so distributed has no effect.")
self._distrib_type = None

View File

@@ -42,6 +42,7 @@ from pytorch_lightning.plugins.environments import (
SLURMEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -613,3 +614,12 @@ def test_devices_with_cpu_only_supports_integer():
with pytest.raises(MisconfigurationException, match="The flag `devices` only supports integer"):
Trainer(accelerator="cpu", devices="1,3")
@pytest.mark.parametrize("training_type", ["ddp2", "dp"])
def test_unsupported_distrib_types_on_cpu(training_type):
    """DP and DDP2 are GPU-only strategies: on CPU the connector must warn and fall back to DDP."""
    expected_warning = "is not supported on CPUs, hence setting the distributed type to `ddp`."
    with pytest.warns(UserWarning, match=expected_warning):
        trainer = Trainer(accelerator=training_type, num_processes=2)
    # After construction, the connector should have coerced the strategy to plain DDP.
    assert trainer._distrib_type == DistributedType.DDP

View File

@@ -1154,6 +1154,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
dict(accelerator="ddp2", gpus=2),
dict(_distrib_type=DistributedType.DDP2, _device_type=DeviceType.GPU, num_gpus=2, num_processes=1),
),
(
dict(accelerator="ddp2", num_processes=2, gpus=None),
dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
),
(
dict(accelerator="dp", num_processes=2, gpus=None),
dict(_distrib_type=DistributedType.DDP, _device_type=DeviceType.CPU, num_gpus=0, num_processes=2),
),
],
)
def test_trainer_config(trainer_kwargs, expected, monkeypatch):