Ensure accelerator is valid if running interactively (#5970)

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
ifsheldon 2021-02-23 16:23:50 +03:00 committed by GitHub
parent 863a70c294
commit ebabe56f4e
5 changed files with 48 additions and 5 deletions


@@ -517,6 +517,9 @@ class AcceleratorConnector(object):
rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
self._distrib_type = None
# finished configuring self._distrib_type, check ipython environment
self.check_interactive_compatibility()
# for DDP overwrite nb processes by requested GPUs
if (
self._device_type == DeviceType.GPU
@@ -558,6 +561,19 @@ class AcceleratorConnector(object):
else:
self.num_processes = hvd.local_size()
def check_interactive_compatibility(self):
"""
Raises a `MisconfigurationException` if the accelerator and/or plugin
is not compatible with an interactive environment
"""
from pytorch_lightning.utilities import _IS_INTERACTIVE
if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible():
raise MisconfigurationException(
f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
" environment. Run your code as a script, or choose one of the compatible backends:"
f" {', '.join(DistributedType.interactive_compatible_types())}"
)
def check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not _HOROVOD_AVAILABLE:
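
For context (this sketch is not part of the diff): with the check above in place, constructing a Trainer with a non-spawn distributed backend from a notebook or REPL fails during initialization, while spawn-based backends are still accepted. A minimal illustration, assuming an interactive session and two visible GPUs:

    # Sketch only, not part of this commit. Assumes an interactive session
    # (so pytorch_lightning.utilities._IS_INTERACTIVE is True) and that gpus=2 is available.
    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        Trainer(accelerator="ddp", gpus=2)  # 'ddp' is not interactive-compatible
    except MisconfigurationException as err:
        print(err)  # suggests running as a script or choosing dp / ddp_spawn / ddp_sharded_spawn

    Trainer(accelerator="ddp_spawn", gpus=2)  # spawn-based backends remain usable interactively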


@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import numpy
from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401
@@ -33,6 +32,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_HOROVOD_AVAILABLE,
_HYDRA_AVAILABLE,
_HYDRA_EXPERIMENTAL_AVAILABLE,
_IS_INTERACTIVE,
_module_available,
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,


@@ -13,14 +13,14 @@
# limitations under the License.
"""Enumerated utilities"""
from enum import Enum
from typing import Union
from typing import List, Optional, Union
class LightningEnum(str, Enum):
""" Type of any enumerator with allowed comparison to string invariant to cases. """
@classmethod
def from_str(cls, value: str) -> 'LightningEnum':
def from_str(cls, value: str) -> Optional['LightningEnum']:
statuses = [status for status in dir(cls) if not status.startswith('_')]
for st in statuses:
if st.lower() == value.lower():
@@ -31,7 +31,7 @@ class LightningEnum(str, Enum):
other = other.value if isinstance(other, Enum) else str(other)
return self.value.lower() == other.lower()
def __hash__(self):
def __hash__(self) -> int:
# re-enable hashtable so it can be used as a dict key or in a set
# example: set(LightningEnum)
return hash(self.name)
@@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
>>> DistributedType.DDP2 in ('ddp2', )
True
"""
@staticmethod
def interactive_compatible_types() -> List['DistributedType']:
"""Returns a list containing interactive compatible DistributeTypes"""
return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
def is_interactive_compatible(self) -> bool:
"""Returns whether self is interactive compatible"""
return self in DistributedType.interactive_compatible_types()
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
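
A quick illustration of the two new helpers on DistributedType (not part of the diff); the whitelist below is exactly what the error message in AcceleratorConnector joins into its suggestion:

    # Sketch only, not part of this commit.
    from pytorch_lightning.utilities.enums import DistributedType

    assert not DistributedType.DDP.is_interactive_compatible()    # needs a script launch
    assert DistributedType.DDP_SPAWN.is_interactive_compatible()  # workers are spawned as subprocesses

    # DistributedType mixes in str, so joining the members yields their string values:
    print(", ".join(DistributedType.interactive_compatible_types()))  # dp, ddp_spawn, ddp_sharded_spawn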


@@ -14,6 +14,7 @@
"""General utilities"""
import operator
import platform
import sys
from distutils.version import LooseVersion
from importlib.util import find_spec
@@ -49,10 +50,11 @@ def _compare_version(package: str, op, version) -> bool:
_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
@@ -65,6 +67,7 @@ _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")
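
For reference (not part of the diff): sys.ps1 is only defined when Python is running an interactive prompt (the standard REPL sets it, and IPython-based shells and notebook kernels set it as well), so hasattr(sys, "ps1") serves as the _IS_INTERACTIVE flag added above. A minimal sketch of the same detection:

    # Sketch only, not part of this commit.
    import sys

    # sys.ps1 holds the ">>> " prompt; it exists in the REPL and under IPython/Jupyter,
    # but not when the interpreter runs a script, so its presence doubles as a flag.
    is_interactive = hasattr(sys, "ps1")

    if is_interactive:
        print("interactive session: only dp / *_spawn backends are allowed")
    else:
        print("script execution: any distributed backend may be selected")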


@@ -32,6 +32,7 @@ from pytorch_lightning.plugins import (
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
@@ -387,6 +388,19 @@ def test_dist_backend_accelerator_mapping(device_count_mock):
trainer.fit(model)
@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_ipython_incompatible_backend_error(*_):
with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
Trainer(accelerator="ddp", gpus=2)
with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
Trainer(accelerator="ddp_cpu", num_processes=2)
with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
Trainer(accelerator="ddp2", gpus=2)
@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],