Ensure accelerator is valid if running interactively (#5970)
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
863a70c294
commit
ebabe56f4e
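For context, the user-facing effect of this change: in an interactive session the Trainer now fails fast at construction time when an incompatible distributed backend is requested. A small illustrative example (assuming an IPython/Jupyter session on a machine with 2 GPUs; not part of the diff):

from pytorch_lightning import Trainer

# Raises MisconfigurationException during Trainer construction in an interactive session:
# "Selected distributed backend ddp is not compatible with an interactive environment. ..."
# Trainer(accelerator="ddp", gpus=2)

# Spawn/DP-style backends remain allowed interactively:
trainer = Trainer(accelerator="ddp_spawn", gpus=2)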
@@ -517,6 +517,9 @@ class AcceleratorConnector(object):
             rank_zero_warn('You are running on single node with no parallelization, so distributed has no effect.')
             self._distrib_type = None
 
+        # finished configuring self._distrib_type, check ipython environment
+        self.check_interactive_compatibility()
+
         # for DDP overwrite nb processes by requested GPUs
         if (
             self._device_type == DeviceType.GPU
@@ -558,6 +561,19 @@ class AcceleratorConnector(object):
         else:
             self.num_processes = hvd.local_size()
 
+    def check_interactive_compatibility(self):
+        """
+        Raises a `MisconfigurationException` if the accelerator and/or plugin
+        is not compatible with an interactive environment
+        """
+        from pytorch_lightning.utilities import _IS_INTERACTIVE
+        if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible():
+            raise MisconfigurationException(
+                f"Selected distributed backend {self._distrib_type} is not compatible with an interactive"
+                " environment. Run your code as a script, or choose one of the compatible backends:"
+                f" {', '.join(DistributedType.interactive_compatible_types())}"
+            )
+
     def check_horovod(self):
         """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
         if not _HOROVOD_AVAILABLE:
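The guard itself is small: if the process is interactive (see `_IS_INTERACTIVE` below) and the resolved `_distrib_type` is not one of the spawn/DP-style backends, a `MisconfigurationException` is raised. A stand-alone sketch of the same logic with stand-in names (illustrative only, not the Lightning implementation):

import sys

_IS_INTERACTIVE = hasattr(sys, "ps1")  # True when a Python/IPython prompt is active

# Stand-in for DistributedType.interactive_compatible_types()
INTERACTIVE_COMPATIBLE = ("dp", "ddp_spawn", "ddp_sharded_spawn")

def check_interactive_compatibility(distrib_type):
    """Raise if the selected backend cannot run from an interactive prompt."""
    if _IS_INTERACTIVE and distrib_type is not None and distrib_type not in INTERACTIVE_COMPATIBLE:
        raise RuntimeError(
            f"Selected distributed backend {distrib_type} is not compatible with an interactive"
            " environment. Run your code as a script, or choose one of the compatible backends:"
            f" {', '.join(INTERACTIVE_COMPATIBLE)}"
        )

check_interactive_compatibility("ddp_spawn")  # always fine
check_interactive_compatibility("ddp")        # raises only inside a REPL/notebook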
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
 
-import numpy
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device  # noqa: F401
@@ -33,6 +32,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
     _HOROVOD_AVAILABLE,
     _HYDRA_AVAILABLE,
     _HYDRA_EXPERIMENTAL_AVAILABLE,
+    _IS_INTERACTIVE,
     _module_available,
     _NATIVE_AMP_AVAILABLE,
     _OMEGACONF_AVAILABLE,
@@ -13,14 +13,14 @@
 # limitations under the License.
 """Enumerated utilities"""
 from enum import Enum
-from typing import Union
+from typing import List, Optional, Union
 
 
 class LightningEnum(str, Enum):
     """ Type of any enumerator with allowed comparison to string invariant to cases. """
 
     @classmethod
-    def from_str(cls, value: str) -> 'LightningEnum':
+    def from_str(cls, value: str) -> Optional['LightningEnum']:
         statuses = [status for status in dir(cls) if not status.startswith('_')]
         for st in statuses:
             if st.lower() == value.lower():
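The `Optional` return annotation reflects that `from_str` falls through to `None` when no member name matches. Illustrative usage (assuming the enum behaves as annotated):

from pytorch_lightning.utilities.enums import DistributedType

assert DistributedType.from_str("ddp") is DistributedType.DDP      # case-insensitive match
assert DistributedType.from_str("DDP2") is DistributedType.DDP2
assert DistributedType.from_str("not-a-backend") is None           # hence Optional['LightningEnum']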
@@ -31,7 +31,7 @@ class LightningEnum(str, Enum):
         other = other.value if isinstance(other, Enum) else str(other)
         return self.value.lower() == other.lower()
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         # re-enable hashtable so it can be used as a dict key or in a set
         # example: set(LightningEnum)
         return hash(self.name)
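As the inline comment notes, defining `__hash__` (over the member name) keeps the enum usable in sets and as dict keys even though `__eq__` compares case-insensitively against plain strings. A short usage snippet:

from pytorch_lightning.utilities.enums import DistributedType

backends = {DistributedType.DDP, DistributedType.DDP_SPAWN}   # set membership works
launchers = {DistributedType.DDP_SPAWN: "spawn"}              # dict keys work too
assert DistributedType.DDP in backends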
@@ -58,6 +58,16 @@ class DistributedType(LightningEnum):
     >>> DistributedType.DDP2 in ('ddp2', )
     True
     """
+
+    @staticmethod
+    def interactive_compatible_types() -> List['DistributedType']:
+        """Returns a list containing interactive compatible DistributeTypes"""
+        return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
+
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether self is interactive compatible"""
+        return self in DistributedType.interactive_compatible_types()
+
     DP = 'dp'
     DDP = 'ddp'
     DDP2 = 'ddp2'
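Usage of the two new helpers, which the accelerator connector relies on above (small illustrative snippet):

from pytorch_lightning.utilities.enums import DistributedType

assert DistributedType.DDP_SPAWN.is_interactive_compatible()
assert not DistributedType.DDP.is_interactive_compatible()

# LightningEnum members are str subclasses, so they join directly into messages,
# exactly as the error text in check_interactive_compatibility does:
print(", ".join(DistributedType.interactive_compatible_types()))  # dp, ddp_spawn, ddp_sharded_spawn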
@@ -14,6 +14,7 @@
 """General utilities"""
 import operator
 import platform
+import sys
 from distutils.version import LooseVersion
 from importlib.util import find_spec
 
@@ -49,10 +50,11 @@ def _compare_version(package: str, op, version) -> bool:
 
+
 _IS_WINDOWS = platform.system() == "Windows"
+_IS_INTERACTIVE = hasattr(sys, "ps1")  # https://stackoverflow.com/a/64523765
 _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
-_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
 
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
@@ -65,6 +67,7 @@ _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
 _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
+_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
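The detection relies on `sys.ps1`, which the interpreter only defines while an interactive prompt (plain REPL, IPython/Jupyter) is active, making it a cheap interactivity probe. A quick check of the flag (illustrative):

import sys

# Absent under scripts and pytest runs; present in interactive interpreters such as the REPL or IPython.
print("interactive" if hasattr(sys, "ps1") else "non-interactive")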
@@ -32,6 +32,7 @@ from pytorch_lightning.plugins import (
     SingleDevicePlugin,
 )
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 
 
@@ -387,6 +388,19 @@ def test_dist_backend_accelerator_mapping(device_count_mock):
     trainer.fit(model)
 
 
+@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
+@mock.patch('torch.cuda.device_count', return_value=2)
+def test_ipython_incompatible_backend_error(*_):
+    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
+        Trainer(accelerator="ddp", gpus=2)
+
+    with pytest.raises(MisconfigurationException, match="backend ddp is not compatible"):
+        Trainer(accelerator="ddp_cpu", num_processes=2)
+
+    with pytest.raises(MisconfigurationException, match="backend ddp2 is not compatible"):
+        Trainer(accelerator="ddp2", gpus=2)
+
+
 @pytest.mark.parametrize(
     ["accelerator", "plugin"],
     [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
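The tests force the interactive flag on by patching the module attribute that `check_interactive_compatibility` re-imports at call time; the decorators above do this with a truthy MagicMock (`return_value=True`). A minimal sketch of the same patching pattern (illustrative, not part of the diff):

from unittest import mock

import pytorch_lightning.utilities as utilities

# Substitute a truthy value for the module-level flag for the duration of the block.
with mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", True):
    assert utilities._IS_INTERACTIVE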