From d0f54609de9c4f5ed797fe6af86bd193679c9928 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 22 Feb 2022 14:50:54 +0100
Subject: [PATCH] Fix `is_interactive_compatible` logic after
 AcceleratorConnector rewrite (#12008)

* fix is_interactive_compatible

* improve tests

* update message

* address review

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../strategies/launchers/base.py              |  5 ++++
 .../strategies/launchers/spawn.py             |  7 ++++++
 .../strategies/launchers/subprocess_script.py |  4 ++++
 .../strategies/launchers/xla_spawn.py         |  4 ++++
 .../connectors/accelerator_connector.py       | 14 ++++-------
 pytorch_lightning/utilities/enums.py          |  2 --
 .../test_accelerator_connector.py             | 23 +++++++++++++++----
 7 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/strategies/launchers/base.py b/pytorch_lightning/strategies/launchers/base.py
index 293c0a2ce4..2acf54afef 100644
--- a/pytorch_lightning/strategies/launchers/base.py
+++ b/pytorch_lightning/strategies/launchers/base.py
@@ -26,6 +26,11 @@ class _Launcher(ABC):
     cluster environment, hardware, strategy, etc.
     """
 
+    @property
+    @abstractmethod
+    def is_interactive_compatible(self) -> bool:
+        """Returns whether this launcher can work in interactive environments such as Jupyter notebooks."""
+
     @abstractmethod
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Launches the processes."""
diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
index 3b393f4a0c..d67f9e620a 100644
--- a/pytorch_lightning/strategies/launchers/spawn.py
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -49,6 +49,13 @@ class _SpawnLauncher(_Launcher):
         self._strategy = strategy
         self._start_method = "spawn"
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        # The start method 'spawn' is currently the only one that works with DDP and CUDA support
+        # The start method 'fork' is the only one supported in Jupyter environments but not compatible with CUDA
+        # For more context, see https://github.com/PyTorchLightning/pytorch-lightning/issues/7550
+        return self._start_method == "fork" and self._strategy.root_device.type != "cuda"
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
 
diff --git a/pytorch_lightning/strategies/launchers/subprocess_script.py b/pytorch_lightning/strategies/launchers/subprocess_script.py
index a99a967a88..5a8632fb87 100644
--- a/pytorch_lightning/strategies/launchers/subprocess_script.py
+++ b/pytorch_lightning/strategies/launchers/subprocess_script.py
@@ -68,6 +68,10 @@ class _SubprocessScriptLauncher(_Launcher):
         num_nodes: The total number of nodes that participate in this process group.
     """
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return False
+
     def __init__(self, cluster_environment: ClusterEnvironment, num_processes: int, num_nodes: int) -> None:
         super().__init__()
         self.cluster_environment = cluster_environment
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py
index 71acfc1011..b3e1bf3465 100644
--- a/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -54,6 +54,10 @@ class _XLASpawnLauncher(_SpawnLauncher):
         super().__init__(strategy)
         self._start_method = "fork"
 
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return True
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index a62074b444..f2d27a249f 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -65,6 +65,7 @@ from pytorch_lightning.strategies import (
     TPUSpawnStrategy,
 )
 from pytorch_lightning.utilities import (
+    _StrategyType,
     AMPType,
     device_parser,
     LightningEnum,
@@ -734,19 +735,12 @@ class AcceleratorConnector:
 
         from pytorch_lightning.utilities import _IS_INTERACTIVE
 
-        # TODO move is_compatible logic to strategy API
-        interactive_compatible_strategy = (
-            DataParallelStrategy.strategy_name,
-            DDPSpawnStrategy.strategy_name,
-            DDPSpawnShardedStrategy.strategy_name,
-            TPUSpawnStrategy.strategy_name,
-        )
-        if _IS_INTERACTIVE and self.strategy.strategy_name not in interactive_compatible_strategy:
+        if _IS_INTERACTIVE and self.strategy.launcher and not self.strategy.launcher.is_interactive_compatible:
             raise MisconfigurationException(
                 f"`Trainer(strategy={self.strategy.strategy_name!r})` or"
                 f" `Trainer(accelerator={self.strategy.strategy_name!r})` is not compatible with an interactive"
-                " environment. Run your code as a script, or choose one of the compatible backends:"
-                f" {', '.join(interactive_compatible_strategy)}."
+                " environment. Run your code as a script, or choose one of the compatible strategies:"
+                f" Trainer(strategy=None|{'|'.join(_StrategyType.interactive_compatible_types())})."
                 " In case you are spawning processes yourself, make sure to include the Trainer"
                 " creation inside the worker function."
             )
diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py
index 103fc87ecd..105b167a29 100644
--- a/pytorch_lightning/utilities/enums.py
+++ b/pytorch_lightning/utilities/enums.py
@@ -254,8 +254,6 @@ class _StrategyType(LightningEnum):
         """Returns a list containing interactive compatible _StrategyTypes."""
         return [
             _StrategyType.DP,
-            _StrategyType.DDP_SPAWN,
-            _StrategyType.DDP_SHARDED_SPAWN,
             _StrategyType.TPU_SPAWN,
         ]
 
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 998062f5aa..0e13b4af0f 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -20,6 +20,7 @@ import pytest
 import torch
 import torch.distributed
 
+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -392,19 +393,31 @@ def test_dist_backend_accelerator_mapping(*_):
     assert trainer.strategy.local_rank == 0
 
 
-@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
 @mock.patch("torch.cuda.device_count", return_value=2)
-def test_ipython_incompatible_backend_error(*_):
+def test_ipython_incompatible_backend_error(_, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
         Trainer(strategy="ddp", gpus=2)
 
     with pytest.raises(MisconfigurationException, match=r"strategy='ddp2'\)`.*is not compatible"):
         Trainer(strategy="ddp2", gpus=2)
 
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_spawn'\)`.*is not compatible"):
+        Trainer(strategy="ddp_spawn")
 
-@mock.patch("pytorch_lightning.utilities._IS_INTERACTIVE", return_value=True)
-def test_ipython_compatible_backend(*_):
-    Trainer(strategy="ddp_spawn", num_processes=2)
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp_sharded_spawn'\)`.*is not compatible"):
+        Trainer(strategy="ddp_sharded_spawn")
+
+    with pytest.raises(MisconfigurationException, match=r"strategy='ddp'\)`.*is not compatible"):
+        # Edge case: AcceleratorConnector maps dp to ddp if accelerator != gpu
+        Trainer(strategy="dp")
+
+
+@pytest.mark.parametrize("trainer_kwargs", [{}, dict(strategy="dp", accelerator="gpu"), dict(accelerator="tpu")])
+def test_ipython_compatible_backend(trainer_kwargs, monkeypatch):
+    monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
+    trainer = Trainer(**trainer_kwargs)
+    assert trainer.strategy.launcher is None or trainer.strategy.launcher.is_interactive_compatible
 
 
 @pytest.mark.parametrize(["accelerator", "plugin"], [("ddp_spawn", "ddp_sharded"), (None, "ddp_sharded")])
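
Usage note (illustrative sketch, not part of the applied diff): with this
patch, whether a strategy may run in an interactive environment is decided by
its launcher's `is_interactive_compatible` property instead of a hard-coded
tuple of strategy names. A minimal sketch of how the new check resolves,
assuming a CPU-only machine and code executed as a script (inside a notebook,
the second `Trainer` construction below would raise
`MisconfigurationException`); the chosen arguments are examples only:

    from pytorch_lightning import Trainer

    # Single-device strategies have no launcher, so they are always allowed
    # in interactive environments.
    trainer = Trainer()
    assert trainer.strategy.launcher is None

    # DDPSpawnStrategy uses _SpawnLauncher with start method "spawn", which
    # is not interactive-compatible (only "fork" without CUDA qualifies).
    trainer = Trainer(strategy="ddp_spawn", num_processes=2)
    assert not trainer.strategy.launcher.is_interactive_compatible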