Update mypy (#11096)

This commit is contained in:
Carlos Mocholí 2021-12-16 17:53:12 +01:00 committed by GitHub
parent cc42aa9401
commit f37bd4677d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 20 deletions

View File

@ -150,7 +150,7 @@ class PrecisionPlugin(CheckpointHooks):
"""Hook to run the optimizer step."""
if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg]
optimizer.step(closure=closure, **kwargs)
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:

View File

@ -3,7 +3,6 @@ import os
import signal
import sys
import threading
from signal import Signals
from subprocess import call
from types import FrameType
from typing import Any, Callable, Dict, List, Set, Union
@ -12,33 +11,38 @@ import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
log = logging.getLogger(__name__)
# copied from signal.pyi
_SIGNUM = Union[int, signal.Signals]
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
_SIGNAL_HANDLER_DICT = Dict[Signals, Union[Callable[[Signals, FrameType], Any], int, None]]
log = logging.getLogger(__name__)
class HandlersCompose:
def __init__(self, signal_handlers: Union[List[Callable], Callable]) -> None:
def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None:
if not isinstance(signal_handlers, list):
signal_handlers = [signal_handlers]
self.signal_handlers = signal_handlers
def __call__(self, signum: Signals, frame: FrameType) -> None:
def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
for signal_handler in self.signal_handlers:
signal_handler(signum, frame)
if isinstance(signal_handler, int):
signal_handler = signal.getsignal(signal_handler)
if callable(signal_handler):
signal_handler(signum, frame)
class SignalConnector:
def __init__(self, trainer: "pl.Trainer") -> None:
self.trainer = trainer
self.trainer._terminate_gracefully = False
self._original_handlers: _SIGNAL_HANDLER_DICT = {}
self._original_handlers: Dict[_SIGNUM, _HANDLER] = {}
def register_signal_handlers(self) -> None:
self._original_handlers = self._get_current_signal_handlers()
sigusr1_handlers: List[Callable] = []
sigterm_handlers: List[Callable] = []
sigusr1_handlers: List[_HANDLER] = []
sigterm_handlers: List[_HANDLER] = []
if _fault_tolerant_training():
sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn)
@ -57,7 +61,7 @@ class SignalConnector:
if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
if self.trainer.is_global_zero:
# save weights
log.info("handling SIGUSR1")
@ -88,22 +92,22 @@ class SignalConnector:
if self.trainer.logger:
self.trainer.logger.finalize("finished")
def fault_tolerant_sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
self.trainer._terminate_gracefully = True
def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
def sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
log.info("bypassing sigterm")
def teardown(self) -> None:
"""Restores the signals that were previsouly configured before :class:`SignalConnector` replaced them."""
for signum, handler in self._original_handlers.items():
if handler is not None:
signal.signal(signum, handler)
signal.signal(signum, handler) # type: ignore[arg-type]
self._original_handlers = {}
@staticmethod
def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT:
def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]:
"""Collects the currently assigned signal handlers."""
valid_signals = SignalConnector._valid_signals()
if not _IS_WINDOWS:
@ -112,7 +116,7 @@ class SignalConnector:
return {signum: signal.getsignal(signum) for signum in valid_signals}
@staticmethod
def _valid_signals() -> Set[Signals]:
def _valid_signals() -> Set[signal.Signals]:
"""Returns all valid signals supported on the current platform.
Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for
@ -138,13 +142,13 @@ class SignalConnector:
return sys.platform == "win32"
@staticmethod
def _has_already_handler(signum: Signals) -> bool:
def _has_already_handler(signum: _SIGNUM) -> bool:
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
@staticmethod
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
if threading.current_thread() is threading.main_thread():
signal.signal(signum, handlers)
signal.signal(signum, handlers) # type: ignore[arg-type]
def __getstate__(self) -> Dict:
state = self.__dict__.copy()

View File

@ -34,6 +34,8 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
raise ValueError("Expected a parent")
instance_attr = getattr(instance, method_name, None)
if instance_attr is None:
return False
# `functools.wraps()` support
if hasattr(instance_attr, "__wrapped__"):
instance_attr = instance_attr.__wrapped__

View File

@ -3,7 +3,7 @@ codecov>=2.1
pytest>=6.0
pytest-rerunfailures>=10.2
twine==3.2
mypy==0.910
mypy>=0.920
flake8>=3.9.2
pre-commit>=1.0