diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 109be55b8d..0472ab42c6 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -150,7 +150,7 @@ class PrecisionPlugin(CheckpointHooks): """Hook to run the optimizer step.""" if isinstance(model, pl.LightningModule): closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) - optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg] + optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if trainer.track_grad_norm == -1: diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index baa17be5fa..daecf5a419 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -3,7 +3,6 @@ import os import signal import sys import threading -from signal import Signals from subprocess import call from types import FrameType from typing import Any, Callable, Dict, List, Set, Union @@ -12,33 +11,38 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS -log = logging.getLogger(__name__) +# copied from signal.pyi +_SIGNUM = Union[int, signal.Signals] +_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None] -_SIGNAL_HANDLER_DICT = Dict[Signals, Union[Callable[[Signals, FrameType], Any], int, None]] +log = logging.getLogger(__name__) class HandlersCompose: - def __init__(self, signal_handlers: Union[List[Callable], Callable]) -> None: + def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None: if not isinstance(signal_handlers, list): signal_handlers = [signal_handlers] self.signal_handlers = signal_handlers - def __call__(self, signum: Signals, frame: FrameType) -> None: + def __call__(self, signum: _SIGNUM, frame: FrameType) -> None: for signal_handler in self.signal_handlers: - signal_handler(signum, frame) + if isinstance(signal_handler, int): + signal_handler = signal.getsignal(signal_handler) + if callable(signal_handler): + signal_handler(signum, frame) class SignalConnector: def __init__(self, trainer: "pl.Trainer") -> None: self.trainer = trainer self.trainer._terminate_gracefully = False - self._original_handlers: _SIGNAL_HANDLER_DICT = {} + self._original_handlers: Dict[_SIGNUM, _HANDLER] = {} def register_signal_handlers(self) -> None: self._original_handlers = self._get_current_signal_handlers() - sigusr1_handlers: List[Callable] = [] - sigterm_handlers: List[Callable] = [] + sigusr1_handlers: List[_HANDLER] = [] + sigterm_handlers: List[_HANDLER] = [] if _fault_tolerant_training(): sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn) @@ -57,7 +61,7 @@ class SignalConnector: if sigterm_handlers and not self._has_already_handler(signal.SIGTERM): self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) - def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: + def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: if self.trainer.is_global_zero: # save weights log.info("handling SIGUSR1") @@ -88,22 +92,22 @@ class SignalConnector: if self.trainer.logger: self.trainer.logger.finalize("finished") - def fault_tolerant_sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: + def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.") self.trainer._terminate_gracefully = True - def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None: + def sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: log.info("bypassing sigterm") def teardown(self) -> None: """Restores the signals that were previsouly configured before :class:`SignalConnector` replaced them.""" for signum, handler in self._original_handlers.items(): if handler is not None: - signal.signal(signum, handler) + signal.signal(signum, handler) # type: ignore[arg-type] self._original_handlers = {} @staticmethod - def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT: + def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]: """Collects the currently assigned signal handlers.""" valid_signals = SignalConnector._valid_signals() if not _IS_WINDOWS: @@ -112,7 +116,7 @@ class SignalConnector: return {signum: signal.getsignal(signum) for signum in valid_signals} @staticmethod - def _valid_signals() -> Set[Signals]: + def _valid_signals() -> Set[signal.Signals]: """Returns all valid signals supported on the current platform. Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for @@ -138,13 +142,13 @@ class SignalConnector: return sys.platform == "win32" @staticmethod - def _has_already_handler(signum: Signals) -> bool: + def _has_already_handler(signum: _SIGNUM) -> bool: return signal.getsignal(signum) not in (None, signal.SIG_DFL) @staticmethod - def _register_signal(signum: Signals, handlers: HandlersCompose) -> None: + def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None: if threading.current_thread() is threading.main_thread(): - signal.signal(signum, handlers) + signal.signal(signum, handlers) # type: ignore[arg-type] def __getstate__(self) -> Dict: state = self.__dict__.copy() diff --git a/pytorch_lightning/utilities/model_helpers.py b/pytorch_lightning/utilities/model_helpers.py index bb48b481e6..66ad264355 100644 --- a/pytorch_lightning/utilities/model_helpers.py +++ b/pytorch_lightning/utilities/model_helpers.py @@ -34,6 +34,8 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O raise ValueError("Expected a parent") instance_attr = getattr(instance, method_name, None) + if instance_attr is None: + return False # `functools.wraps()` support if hasattr(instance_attr, "__wrapped__"): instance_attr = instance_attr.__wrapped__ diff --git a/requirements/test.txt b/requirements/test.txt index 2f2228126a..941b53dc8c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -3,7 +3,7 @@ codecov>=2.1 pytest>=6.0 pytest-rerunfailures>=10.2 twine==3.2 -mypy==0.910 +mypy>=0.920 flake8>=3.9.2 pre-commit>=1.0