diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 7021081d6c..1b2eed56b8 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -517,6 +517,9 @@ class AcceleratorConnector(object):
                 rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
                 self._distrib_type = None
 
+        # finished configuring self._distrib_type, check ipython environment
+        self.check_interactive_compatibility()
+
         # for DDP overwrite nb processes by requested GPUs
         if (
             self._device_type == DeviceType.GPU
@@ -558,6 +561,19 @@ class AcceleratorConnector(object):
         else:
             self.num_processes = hvd.local_size()
 
+    def check_interactive_compatibility(self):
+        """
+        Raises a `MisconfigurationException` if the accelerator and/or plugin
+        is not compatible with an interactive environment
+        """
+        from pytorch_lightning.utilities import _IS_INTERACTIVE
+        if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible():
+            raise MisconfigurationException(
+                f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
+                " environment. Run your code as a script, or choose one of the compatible backends:"
+                f" {', '.join(DistributedType.interactive_compatible_types())}"
+            )
+
     def check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not _HOROVOD_AVAILABLE:
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index cf3aa06f30..3e2ee3e51e 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
-
 import numpy
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
@@ -33,6 +32,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,
     _HYDRA_EXPERIMENTAL_AVAILABLE,
+    _IS_INTERACTIVE,
     _module_available,
     _NATIVE_AMP_AVAILABLE,
     _OMEGACONF_AVAILABLE,
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 3e4add4fb6..169481fa63 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 """Enumerated utilities"""
 from enum import Enum
-from typing import Union
+from typing import List, Optional, Union
 
 
 class LightningEnum(str, Enum):
     """ Type of any enumerator with allowed comparison to string invariant to cases. """
 
     @classmethod
-    def from_str(cls, value: str) -> 'LightningEnum':
+    def from_str(cls, value: str) -> Optional['LightningEnum']:
         statuses = [status for status in dir(cls) if not status.startswith('_')]
         for st in statuses:
             if st.lower() == value.lower():
@@ -31,7 +31,7 @@ class LightningEnum(str, Enum):
         other = other.value if isinstance(other, Enum) else str(other)
         return self.value.lower() == other.lower()
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         # re-enable hashtable so it can be used as a dict key or in a set
         # example: set(LightningEnum)
         return hash(self.name)
@@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
     >>> DistributedType.DDP2 in ('ddp2', )
     True
     """
+
+    @staticmethod
+    def interactive_compatible_types() -> List['DistributedType']:
+        """Returns a list containing interactive compatible DistributeTypes"""
+        return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
+
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether self is interactive compatible"""
+        return self in DistributedType.interactive_compatible_types()
+
     DP = 'dp'
     DDP = 'ddp'
     DDP2 = 'ddp2'
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 8024997382..41a13d6c67 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -14,6 +14,7 @@
 """General utilities"""
 import operator
 import platform
+import sys
 from distutils.version import LooseVersion
 from importlib.util import find_spec
 
@@ -49,10 +50,11 @@ def _compare_version(package: str, op, version) -> bool:
 
 
 _IS_WINDOWS = platform.system() == "Windows"
+_IS_INTERACTIVE = hasattr(sys, "ps1")  # https://stackoverflow.com/a/64523765
 _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
-_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
+
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
@@ -65,6 +67,7 @@ _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
 _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
+_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 82b631807c..e41c891797 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -32,6 +32,7 @@ from pytorch_lightning.plugins import (
     SingleDevicePlugin,
 )
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 
 
@@ -387,6 +388,19 @@ def test_dist_backend_accelerator_mapping(device_count_mock):
     trainer.fit(model)
 
 
+@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=2)
+def test_ipython_incompatible_backend_error(*_):
+    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
+        Trainer(accelerator="ddp", gpus=2)
+
+    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
+        Trainer(accelerator="ddp_cpu", num_processes=2)
+
+    with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
+        Trainer(accelerator="ddp2", gpus=2)
+
+
 @pytest.mark.parametrize(
     ["accelerator", "plugin"],
     [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
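Note (illustrative, not part of the diff above): a rough sketch of how the new guard is expected to surface for a user, assuming an interactive session (e.g. a Jupyter/IPython kernel, where hasattr(sys, "ps1") is true and therefore _IS_INTERACTIVE is set) on a machine with two visible GPUs:

    # Sketch only; assumes an interactive (Jupyter/IPython) session and two visible GPUs.
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        # 'ddp' is not in DistributedType.interactive_compatible_types(), so the new
        # check_interactive_compatibility() raises while the Trainer is being configured.
        Trainer(accelerator="ddp", gpus=2)
    except MisconfigurationException as err:
        print(err)  # suggests running as a script or choosing dp / ddp_spawn / ddp_sharded_spawn

    # 'ddp_spawn' is listed as interactive-compatible, so this constructs without raising.
    trainer = Trainer(accelerator="ddp_spawn", gpus=2)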