Update mypy (#11096)
This commit is contained in:
parent
cc42aa9401
commit
f37bd4677d
|
@ -150,7 +150,7 @@ class PrecisionPlugin(CheckpointHooks):
|
|||
"""Hook to run the optimizer step."""
|
||||
if isinstance(model, pl.LightningModule):
|
||||
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
|
||||
optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg]
|
||||
optimizer.step(closure=closure, **kwargs)
|
||||
|
||||
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
|
||||
if trainer.track_grad_norm == -1:
|
||||
|
|
|
@ -3,7 +3,6 @@ import os
|
|||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from signal import Signals
|
||||
from subprocess import call
|
||||
from types import FrameType
|
||||
from typing import Any, Callable, Dict, List, Set, Union
|
||||
|
@ -12,33 +11,38 @@ import pytorch_lightning as pl
|
|||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
# copied from signal.pyi
|
||||
_SIGNUM = Union[int, signal.Signals]
|
||||
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
|
||||
|
||||
_SIGNAL_HANDLER_DICT = Dict[Signals, Union[Callable[[Signals, FrameType], Any], int, None]]
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HandlersCompose:
|
||||
def __init__(self, signal_handlers: Union[List[Callable], Callable]) -> None:
|
||||
def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None:
|
||||
if not isinstance(signal_handlers, list):
|
||||
signal_handlers = [signal_handlers]
|
||||
self.signal_handlers = signal_handlers
|
||||
|
||||
def __call__(self, signum: Signals, frame: FrameType) -> None:
|
||||
def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
|
||||
for signal_handler in self.signal_handlers:
|
||||
signal_handler(signum, frame)
|
||||
if isinstance(signal_handler, int):
|
||||
signal_handler = signal.getsignal(signal_handler)
|
||||
if callable(signal_handler):
|
||||
signal_handler(signum, frame)
|
||||
|
||||
|
||||
class SignalConnector:
|
||||
def __init__(self, trainer: "pl.Trainer") -> None:
|
||||
self.trainer = trainer
|
||||
self.trainer._terminate_gracefully = False
|
||||
self._original_handlers: _SIGNAL_HANDLER_DICT = {}
|
||||
self._original_handlers: Dict[_SIGNUM, _HANDLER] = {}
|
||||
|
||||
def register_signal_handlers(self) -> None:
|
||||
self._original_handlers = self._get_current_signal_handlers()
|
||||
|
||||
sigusr1_handlers: List[Callable] = []
|
||||
sigterm_handlers: List[Callable] = []
|
||||
sigusr1_handlers: List[_HANDLER] = []
|
||||
sigterm_handlers: List[_HANDLER] = []
|
||||
|
||||
if _fault_tolerant_training():
|
||||
sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn)
|
||||
|
@ -57,7 +61,7 @@ class SignalConnector:
|
|||
if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
|
||||
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
|
||||
|
||||
def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
|
||||
def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
|
||||
if self.trainer.is_global_zero:
|
||||
# save weights
|
||||
log.info("handling SIGUSR1")
|
||||
|
@ -88,22 +92,22 @@ class SignalConnector:
|
|||
if self.trainer.logger:
|
||||
self.trainer.logger.finalize("finished")
|
||||
|
||||
def fault_tolerant_sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
|
||||
def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
|
||||
log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
|
||||
self.trainer._terminate_gracefully = True
|
||||
|
||||
def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
|
||||
def sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
|
||||
log.info("bypassing sigterm")
|
||||
|
||||
def teardown(self) -> None:
|
||||
"""Restores the signals that were previsouly configured before :class:`SignalConnector` replaced them."""
|
||||
for signum, handler in self._original_handlers.items():
|
||||
if handler is not None:
|
||||
signal.signal(signum, handler)
|
||||
signal.signal(signum, handler) # type: ignore[arg-type]
|
||||
self._original_handlers = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT:
|
||||
def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]:
|
||||
"""Collects the currently assigned signal handlers."""
|
||||
valid_signals = SignalConnector._valid_signals()
|
||||
if not _IS_WINDOWS:
|
||||
|
@ -112,7 +116,7 @@ class SignalConnector:
|
|||
return {signum: signal.getsignal(signum) for signum in valid_signals}
|
||||
|
||||
@staticmethod
|
||||
def _valid_signals() -> Set[Signals]:
|
||||
def _valid_signals() -> Set[signal.Signals]:
|
||||
"""Returns all valid signals supported on the current platform.
|
||||
|
||||
Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for
|
||||
|
@ -138,13 +142,13 @@ class SignalConnector:
|
|||
return sys.platform == "win32"
|
||||
|
||||
@staticmethod
|
||||
def _has_already_handler(signum: Signals) -> bool:
|
||||
def _has_already_handler(signum: _SIGNUM) -> bool:
|
||||
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
|
||||
|
||||
@staticmethod
|
||||
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
|
||||
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signum, handlers)
|
||||
signal.signal(signum, handlers) # type: ignore[arg-type]
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = self.__dict__.copy()
|
||||
|
|
|
@ -34,6 +34,8 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
|
|||
raise ValueError("Expected a parent")
|
||||
|
||||
instance_attr = getattr(instance, method_name, None)
|
||||
if instance_attr is None:
|
||||
return False
|
||||
# `functools.wraps()` support
|
||||
if hasattr(instance_attr, "__wrapped__"):
|
||||
instance_attr = instance_attr.__wrapped__
|
||||
|
|
|
@ -3,7 +3,7 @@ codecov>=2.1
|
|||
pytest>=6.0
|
||||
pytest-rerunfailures>=10.2
|
||||
twine==3.2
|
||||
mypy==0.910
|
||||
mypy>=0.920
|
||||
flake8>=3.9.2
|
||||
pre-commit>=1.0
|
||||
|
||||
|
|
Loading…
Reference in New Issue