diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
index 30c3a21025..aa96da63ad 100644
--- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py
+++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -25,6 +25,7 @@ import torch.backends.cudnn
 import torch.multiprocessing as mp
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.strategies.launchers.multiprocessing import (
@@ -83,12 +84,14 @@ class _MultiProcessingLauncher(_Launcher):
         self._already_fit = False
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         # The start method 'spawn' is not supported in interactive environments
         # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
         # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
         return self._start_method == "fork"
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Launches processes that run the given function in parallel.
 
@@ -252,6 +255,7 @@ class _MultiProcessingLauncher(_Launcher):
         callback_metrics = extra["callback_metrics"]
         trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
 
+    @override
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
             if proc.is_alive() and proc.pid is not None:
diff --git a/src/lightning/pytorch/strategies/launchers/subprocess_script.py b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
index 5748c0def1..03dbbc5236 100644
--- a/src/lightning/pytorch/strategies/launchers/subprocess_script.py
+++ b/src/lightning/pytorch/strategies/launchers/subprocess_script.py
@@ -17,6 +17,7 @@ import subprocess
 from typing import Any, Callable, List, Optional
 
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
@@ -79,9 +80,11 @@ class _SubprocessScriptLauncher(_Launcher):
         self.procs: List[subprocess.Popen] = []  # launched child subprocesses, does not include the launcher
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         return False
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Creates new processes, then calls the given function.
 
@@ -101,6 +104,7 @@ class _SubprocessScriptLauncher(_Launcher):
         _set_num_threads_if_needed(num_processes=self.num_processes)
         return function(*args, **kwargs)
 
+    @override
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
             log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py
index 4b29e5eefa..d23f3d896f 100644
--- a/src/lightning/pytorch/strategies/launchers/xla.py
+++ b/src/lightning/pytorch/strategies/launchers/xla.py
@@ -16,6 +16,7 @@ import queue
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
+from typing_extensions import override
 
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.strategies.launchers.xla import _rank_teardown
@@ -55,9 +56,11 @@ class _XLALauncher(_MultiProcessingLauncher):
         super().__init__(strategy=strategy, start_method="fork")
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         return True
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Launches processes that run the given function in parallel.
 
@@ -116,6 +119,7 @@ class _XLALauncher(_MultiProcessingLauncher):
         self._recover_results_in_main_process(worker_output, trainer)
         return worker_output.trainer_results
 
+    @override
     def _wrapping_function(
         self,
         # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
@@ -147,6 +151,7 @@
 
         _rank_teardown(self._strategy.local_rank)
 
+    @override
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback
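
Note: the sketch below is illustrative commentary, not part of the patch. It shows what the typing_extensions.override decorator imported and applied in these hunks provides: a static type checker (e.g. mypy or pyright) can flag a method marked @override that no longer matches anything in its base class, while runtime behavior stays unchanged. The _BaseLauncher/_MyLauncher classes are hypothetical stand-ins, not the real Lightning launchers.

    from typing_extensions import override


    class _BaseLauncher:
        def kill(self, signum: int) -> None:
            ...


    class _MyLauncher(_BaseLauncher):
        @override
        def kill(self, signum: int) -> None:  # OK: overrides _BaseLauncher.kill
            print(f"killing with signal {signum}")

        @override
        def kil(self, signum: int) -> None:  # flagged by type checkers: no base method named 'kil'
            ...


    # At runtime the decorator only marks the function and returns it unchanged.
    assert getattr(_MyLauncher.kill, "__override__", False)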