From e4c7c5f4bc3771642761b81f77ee013ec896598f Mon Sep 17 00:00:00 2001 From: himkt Date: Fri, 15 Jul 2022 21:03:08 +0900 Subject: [PATCH] Fix mypy errors attributed to pytorch_lightning/loops/epoch/training_epoch_loop.py (#13555) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pyproject.toml | 1 - .../loops/epoch/training_epoch_loop.py | 24 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e886795e23..0ddadd2b29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,6 @@ module = [ "pytorch_lightning.loggers.neptune", "pytorch_lightning.loggers.tensorboard", "pytorch_lightning.loggers.wandb", - "pytorch_lightning.loops.epoch.training_epoch_loop", "pytorch_lightning.profilers.advanced", "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 630b3acf17..de07acdc90 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -13,7 +13,7 @@ # limitations under the License. import math from collections import defaultdict, OrderedDict -from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union +from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, Tuple, Union import numpy as np import torch @@ -286,7 +286,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available - self._dataloader_state_dict = state_dict.get("dataloader_state_dict") + self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0) def _run_validation(self) -> None: @@ -331,7 +331,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]: """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook.""" if not batch_output: - return [] + return [] # type: ignore[return-value] # convert optimizer dicts to list if lightning_module.automatic_optimization: @@ -373,7 +373,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook.""" # `batch_outputs` (plural) is the same as `epoch_end_output` (singular) if not batch_outputs: - return [] + return [] # type: ignore[return-value] # convert optimizer dicts to list if lightning_module.automatic_optimization: @@ -455,8 +455,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): if config.interval == interval and current_idx % config.frequency == 0: monitor_val = None if config.reduce_on_plateau: - # If instance of ReduceLROnPlateau, we need a monitor monitor_key = config.monitor + assert monitor_key is not None monitor_val = self._get_monitor_value(monitor_key) if monitor_val is None: if config.strict: @@ -485,11 +485,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): ) self.scheduler_progress.increment_completed() - def _get_monitor_value(self, key: str) -> Any: + def _get_monitor_value(self, key: str) -> Optional[Any]: # this is a separate method to aid in testing return self.trainer.callback_metrics.get(key) - def _should_check_val_epoch(self): + def _should_check_val_epoch(self) -> bool: return self.trainer.enable_validation and ( self.trainer.check_val_every_n_epoch is None or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 @@ -531,7 +531,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None: if self._dataloader_state_dict: data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict) - self._dataloader_state_dict = None + self._dataloader_state_dict = {} def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict: """Helper method to build the arguments for the current step. @@ -564,12 +564,12 @@ def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> @overload -def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any: +def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]: ... @overload -def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]: +def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any: ... @@ -587,7 +587,7 @@ def _recursive_unpad(nested: Union[Any, List[Any]], value: Optional[Any] = None) return [_recursive_unpad(item, value) for item in nested if item != value] -def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.array: +def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.ndarray: """Pads a jagged nested list of lists with the given value such that a proper multi-dimensional array can be formed with rectangular shape. The padding appends to the incomplete lists. @@ -618,7 +618,7 @@ def _get_max_shape(array: List[Any]) -> List[int]: >>> _get_max_shape([[], [[1], [2]], []]) [3, 2, 1] """ - dimensions = defaultdict(int) + dimensions: DefaultDict[int, int] = defaultdict(int) for level, length in _get_dimensions(array): dimensions[level] = max(dimensions[level], length) return [value for _, value in sorted(dimensions.items())]