Add `@override` for subclasses of PyTorch `_Launcher` (#18922)

Victor Prins authored on 2023-11-03 02:11:39 +01:00, committed by GitHub
parent 90d3d1ca7e
commit 809e952434
3 changed files with 13 additions and 0 deletions
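
The decorator being added here is `override` from `typing_extensions` (PEP 698). It has no runtime effect beyond setting `__override__ = True` on the decorated function; it exists so static type checkers can verify that a method marked as an override really does override something on a base class, and complain if the base method is later renamed or its signature changes. A minimal sketch of the idea, with made-up class names rather than the actual Lightning launchers:

```python
from typing_extensions import override


class BaseLauncher:  # stand-in for an abstract launcher interface
    def launch(self, function, *args, **kwargs):
        raise NotImplementedError


class MyLauncher(BaseLauncher):
    @override
    def launch(self, function, *args, **kwargs):
        # No runtime behaviour change; mypy/pyright flag this method if
        # BaseLauncher no longer defines a matching `launch`, catching
        # silent API drift between base class and subclass.
        return function(*args, **kwargs)
```

Because the decorator is a no-op at runtime, the diff below is purely additive: three imports plus ten `@override` markers, with no functional change.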


@@ -25,6 +25,7 @@ import torch.backends.cudnn
import torch.multiprocessing as mp
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.strategies.launchers.multiprocessing import (
@@ -83,12 +84,14 @@ class _MultiProcessingLauncher(_Launcher):
        self._already_fit = False

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        # The start method 'spawn' is not supported in interactive environments
        # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
        # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
        return self._start_method == "fork"

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.
@@ -252,6 +255,7 @@ class _MultiProcessingLauncher(_Launcher):
        callback_metrics = extra["callback_metrics"]
        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))

    @override
    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            if proc.is_alive() and proc.pid is not None:
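
The comment in `is_interactive_compatible` above captures why only the 'fork' start method works from a notebook: 'spawn' must re-import a `__main__` module, which a Jupyter kernel does not have, whereas 'fork' inherits the live interpreter state (with caveats if CUDA was already initialized). A rough, hypothetical helper illustrating that constraint, not part of this diff:

```python
import multiprocessing


def pick_interactive_safe_start_method() -> str:
    # Hypothetical helper for illustration: prefer 'fork' when the platform
    # offers it (POSIX only), because 'spawn' cannot re-import a notebook's
    # __main__ module and therefore fails in interactive sessions.
    methods = multiprocessing.get_all_start_methods()
    return "fork" if "fork" in methods else "spawn"
```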


@@ -17,6 +17,7 @@ import subprocess
from typing import Any, Callable, List, Optional

from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.plugins import ClusterEnvironment
@@ -79,9 +80,11 @@ class _SubprocessScriptLauncher(_Launcher):
        self.procs: List[subprocess.Popen] = []  # launched child subprocesses, does not include the launcher

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        return False

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Creates new processes, then calls the given function.
@@ -101,6 +104,7 @@ class _SubprocessScriptLauncher(_Launcher):
        _set_num_threads_if_needed(num_processes=self.num_processes)
        return function(*args, **kwargs)

    @override
    def kill(self, signum: _SIGNUM) -> None:
        for proc in self.procs:
            log.info(f"pid {os.getpid()} killing {proc.pid} with {signum}")
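
Both launchers also gain `@override` on their `kill(signum)` hook, which forwards a signal to every child process the launcher started (the method body is truncated in the hunk above). A small self-contained sketch of that pattern for `subprocess.Popen` children, illustrative only and not the exact Lightning implementation:

```python
import os
import signal
import subprocess
from typing import List


def kill_children(procs: List[subprocess.Popen], signum: signal.Signals) -> None:
    # Forward the received signal to every child that is still running,
    # mirroring the kill() hook shown above.
    for proc in procs:
        print(f"pid {os.getpid()} killing {proc.pid} with {signum}")
        if proc.poll() is None:  # child has not exited yet
            try:
                proc.send_signal(signum)
            except ProcessLookupError:
                pass  # the child exited between the poll and the signal
```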


@@ -16,6 +16,7 @@ import queue
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch.multiprocessing as mp
from typing_extensions import override

from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
from lightning.fabric.strategies.launchers.xla import _rank_teardown
@@ -55,9 +56,11 @@ class _XLALauncher(_MultiProcessingLauncher):
        super().__init__(strategy=strategy, start_method="fork")

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        return True

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.
@@ -116,6 +119,7 @@ class _XLALauncher(_MultiProcessingLauncher):
        self._recover_results_in_main_process(worker_output, trainer)
        return worker_output.trainer_results

    @override
    def _wrapping_function(
        self,
        # XLA's multiprocessing returns the global index, not the local index as torch's multiprocessing
@@ -147,6 +151,7 @@ class _XLALauncher(_MultiProcessingLauncher):
        _rank_teardown(self._strategy.local_rank)

    @override
    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback