Fix mypy errors attributed to pytorch_lightning/loops/epoch/training_epoch_loop.py (#13555)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
himkt 2022-07-15 21:03:08 +09:00 committed by GitHub
parent 8355ba1260
commit e4c7c5f4bc
2 changed files with 12 additions and 13 deletions

pyproject.toml

@@ -63,7 +63,6 @@ module = [
     "pytorch_lightning.loggers.neptune",
     "pytorch_lightning.loggers.tensorboard",
     "pytorch_lightning.loggers.wandb",
-    "pytorch_lightning.loops.epoch.training_epoch_loop",
     "pytorch_lightning.profilers.advanced",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",

pytorch_lightning/loops/epoch/training_epoch_loop.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import math
 from collections import defaultdict, OrderedDict
-from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union
+from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, Tuple, Union

 import numpy as np
 import torch
@@ -286,7 +286,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
-        self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
         self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)

     def _run_validation(self) -> None:
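Note on the change above: `dict.get(key)` with no default is typed `Optional[...]`, which would force `_dataloader_state_dict` to be Optional at every use site; the `{}` default (mirrored by the reset in `_reload_dataloader_state_dict` further down) keeps the attribute a plain dict. A standalone sketch of the pattern, with illustrative names:

from typing import Any, Dict

checkpoint: Dict[str, Any] = {"dataloader_state_dict": {"counter": 1}}

# checkpoint.get("...") alone is typed Optional[Any]; mypy rejects assigning
# it to a non-Optional target, while the {} default keeps it a plain Dict
state: Dict[str, Any] = checkpoint.get("dataloader_state_dict", {})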
@@ -331,7 +331,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
         """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook."""
         if not batch_output:
-            return []
+            return []  # type: ignore[return-value]

         # convert optimizer dicts to list
         if lightning_module.automatic_optimization:
@@ -373,7 +373,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook."""
         # `batch_outputs` (plural) is the same as `epoch_end_output` (singular)
         if not batch_outputs:
-            return []
+            return []  # type: ignore[return-value]

         # convert optimizer dicts to list
         if lightning_module.automatic_optimization:
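Both `return []` ignores above are scoped to a single mypy error code rather than a bare `# type: ignore`, so any other error on the same line would still surface; presumably mypy cannot decide which member of the declared `Union[...]` the empty list literal belongs to. A reduced illustration with a hypothetical function:

from typing import Any, Dict, List, Union

def empty_outputs() -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
    # the bracketed error code silences only the return-value complaint,
    # unlike a bare `# type: ignore`, which hides everything on the line
    return []  # type: ignore[return-value]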
@@ -455,8 +455,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             if config.interval == interval and current_idx % config.frequency == 0:
                 monitor_val = None
                 if config.reduce_on_plateau:
-                    # If instance of ReduceLROnPlateau, we need a monitor
                     monitor_key = config.monitor
+                    assert monitor_key is not None
                     monitor_val = self._get_monitor_value(monitor_key)
                     if monitor_val is None:
                         if config.strict:
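Here `config.monitor` is presumably typed `Optional[str]`, while `_get_monitor_value` takes a plain `str`; the added `assert` narrows the type so the call checks. The pattern in isolation, with hypothetical names:

from typing import Dict, Optional

def get_monitor_value(metrics: Dict[str, float], key: Optional[str]) -> Optional[float]:
    assert key is not None  # mypy narrows Optional[str] to str from here on
    return metrics.get(key)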
@@ -485,11 +485,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
                 )
                 self.scheduler_progress.increment_completed()

-    def _get_monitor_value(self, key: str) -> Any:
+    def _get_monitor_value(self, key: str) -> Optional[Any]:
         # this is a separate method to aid in testing
         return self.trainer.callback_metrics.get(key)

-    def _should_check_val_epoch(self):
+    def _should_check_val_epoch(self) -> bool:
         return self.trainer.enable_validation and (
             self.trainer.check_val_every_n_epoch is None
             or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
@@ -531,7 +531,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
         if self._dataloader_state_dict:
             data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict)
-            self._dataloader_state_dict = None
+            self._dataloader_state_dict = {}

     def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
         """Helper method to build the arguments for the current step.
@@ -564,12 +564,12 @@ def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) ->


 @overload
-def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any:
+def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]:
     ...


 @overload
-def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]:
+def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any:
     ...
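The two `@overload` variants above are swapped so that the narrower `List[Any]` signature comes before the catch-all `Any` one: mypy matches overloads in declaration order and reports a variant as unmatchable when an earlier signature is the same or broader. A reduced sketch of the rule, with a hypothetical function:

from typing import Any, List, overload

@overload
def unpad(nested: List[Any]) -> List[Any]: ...
@overload
def unpad(nested: Any) -> Any: ...  # broad catch-all last, so it cannot shadow the list variant

def unpad(nested: Any) -> Any:
    # shared implementation for both overloads
    if isinstance(nested, list):
        return [unpad(item) for item in nested]
    return nested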
@@ -587,7 +587,7 @@ def _recursive_unpad(nested: Union[Any, List[Any]], value: Optional[Any] = None)
     return [_recursive_unpad(item, value) for item in nested if item != value]


-def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.array:
+def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.ndarray:
     """Pads a jagged nested list of lists with the given value such that a proper multi-dimensional array can be
     formed with rectangular shape. The padding appends to the incomplete lists.
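`np.array` is a constructor function, not a class, so mypy rejects it when used as an annotation; `np.ndarray` is the actual array type. In isolation:

import numpy as np

def to_array(rows: list) -> np.ndarray:  # np.ndarray is a real type; np.array is only a factory
    return np.array(rows)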
@@ -618,7 +618,7 @@ def _get_max_shape(array: List[Any]) -> List[int]:
     >>> _get_max_shape([[], [[1], [2]], []])
     [3, 2, 1]
     """
-    dimensions = defaultdict(int)
+    dimensions: DefaultDict[int, int] = defaultdict(int)
     for level, length in _get_dimensions(array):
         dimensions[level] = max(dimensions[level], length)
     return [value for _, value in sorted(dimensions.items())]
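A bare `defaultdict(int)` gives mypy nothing to infer the key type from, which triggers a need-type-annotation error; the explicit `DefaultDict[int, int]`, enabled by the new `typing` import in the first hunk of this file, pins both key and value types:

from collections import defaultdict
from typing import DefaultDict

dimensions: DefaultDict[int, int] = defaultdict(int)  # without this annotation, mypy asks for one
dimensions[0] = max(dimensions[0], 3)
print(dimensions[0])  # 3; missing keys default to int() == 0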