Add `@override` for subclasses of PyTorch `_Launcher` (#18922)
This commit is contained in:
parent 90d3d1ca7e
commit 809e952434
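For context, `typing_extensions.override` (PEP 698) is a runtime no-op that marks a method as intentionally overriding a base-class member, so static type checkers such as mypy or pyright can report when the base method is renamed, removed, or its signature drifts. Below is a minimal, hypothetical sketch of the pattern this commit applies; the subclass name here is an illustrative stand-in, not Lightning's actual code.

```python
# Minimal sketch (illustrative, not Lightning source): how `@override` guards
# subclass methods. It does nothing at runtime beyond setting `__override__`,
# but a type checker errors if the decorated method overrides nothing.
from abc import ABC, abstractmethod
from typing import Any, Callable

from typing_extensions import override


class _Launcher(ABC):  # stand-in for the real abstract launcher interface
    @property
    @abstractmethod
    def is_interactive_compatible(self) -> bool: ...

    @abstractmethod
    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: ...


class _EchoLauncher(_Launcher):  # hypothetical subclass
    @property
    @override  # mirrors the diff: @override sits beneath @property, on the getter
    def is_interactive_compatible(self) -> bool:
        return False

    @override
    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        # A real launcher would spawn worker processes; here we just call through.
        return function(*args, **kwargs)

    # @override
    # def lanuch(self, ...) -> Any: ...
    #     ^ a typo like this would now be flagged by the type checker
    #       ("does not override any base class method") instead of
    #       silently defining a new, never-called method.


print(_EchoLauncher().launch(lambda x: x * 2, 21))  # -> 42
```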
@@ -25,6 +25,7 @@ import torch.backends.cudnn
 import torch.multiprocessing as mp
 from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.strategies.launchers.multiprocessing import (
@@ -83,12 +84,14 @@ class _MultiProcessingLauncher(_Launcher):
         self._already_fit = False
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         # The start method 'spawn' is not supported in interactive environments
         # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
         # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
         return self._start_method == "fork"
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Launches processes that run the given function in parallel.
 
@@ -252,6 +255,7 @@ class _MultiProcessingLauncher(_Launcher):
         callback_metrics = extra["callback_metrics"]
         trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
 
+    @override
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
             if proc.is_alive() and proc.pid is not None:
@@ -17,6 +17,7 @@ import subprocess
 from typing import Any, Callable, List, Optional
 
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
@@ -79,9 +80,11 @@ class _SubprocessScriptLauncher(_Launcher):
         self.procs: List[subprocess.Popen] = []  # launched child subprocesses, does not include the launcher
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         return False
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Creates new processes, then calls the given function.
 
@@ -101,6 +104,7 @@ class _SubprocessScriptLauncher(_Launcher):
         _set_num_threads_if_needed(num_processes=self.num_processes)
         return function(*args, **kwargs)
 
+    @override
     def kill(self, signum: _SIGNUM) -> None:
         for proc in self.procs:
             log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
@@ -16,6 +16,7 @@ import queue
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
+from typing_extensions import override
 
 from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
 from lightning.fabric.strategies.launchers.xla import _rank_teardown
@@ -55,9 +56,11 @@ class _XLALauncher(_MultiProcessingLauncher):
         super().__init__(strategy=strategy, start_method="fork")
 
     @property
+    @override
     def is_interactive_compatible(self) -> bool:
         return True
 
+    @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Launches processes that run the given function in parallel.
 
@@ -116,6 +119,7 @@ class _XLALauncher(_MultiProcessingLauncher):
         self._recover_results_in_main_process(worker_output, trainer)
         return worker_output.trainer_results
 
+    @override
     def _wrapping_function(
         self,
         # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
@@ -147,6 +151,7 @@ class _XLALauncher(_MultiProcessingLauncher):
 
         _rank_teardown(self._strategy.local_rank)
 
+    @override
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
         rank_zero_debug("Collecting results from rank 0 process.")
         checkpoint_callback = trainer.checkpoint_callback