Fix mypy errors attributed to pytorch_lightning/loops/epoch/training_epoch_loop.py (#13555)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
8355ba1260
commit
e4c7c5f4bc
|
@ -63,7 +63,6 @@ module = [
|
|||
"pytorch_lightning.loggers.neptune",
|
||||
"pytorch_lightning.loggers.tensorboard",
|
||||
"pytorch_lightning.loggers.wandb",
|
||||
"pytorch_lightning.loops.epoch.training_epoch_loop",
|
||||
"pytorch_lightning.profilers.advanced",
|
||||
"pytorch_lightning.profilers.base",
|
||||
"pytorch_lightning.profilers.pytorch",
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import math
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union
|
||||
from typing import Any, DefaultDict, Dict, Generator, List, Optional, overload, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -286,7 +286,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
|
||||
def on_load_checkpoint(self, state_dict: Dict) -> None:
|
||||
# cache the dataloader state dict until the dataloader objects are available
|
||||
self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
|
||||
self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {})
|
||||
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
|
||||
|
||||
def _run_validation(self) -> None:
|
||||
|
@ -331,7 +331,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
|
||||
"""Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook."""
|
||||
if not batch_output:
|
||||
return []
|
||||
return [] # type: ignore[return-value]
|
||||
|
||||
# convert optimizer dicts to list
|
||||
if lightning_module.automatic_optimization:
|
||||
|
@ -373,7 +373,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
"""Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook."""
|
||||
# `batch_outputs` (plural) is the same as `epoch_end_output` (singular)
|
||||
if not batch_outputs:
|
||||
return []
|
||||
return [] # type: ignore[return-value]
|
||||
|
||||
# convert optimizer dicts to list
|
||||
if lightning_module.automatic_optimization:
|
||||
|
@ -455,8 +455,8 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
if config.interval == interval and current_idx % config.frequency == 0:
|
||||
monitor_val = None
|
||||
if config.reduce_on_plateau:
|
||||
# If instance of ReduceLROnPlateau, we need a monitor
|
||||
monitor_key = config.monitor
|
||||
assert monitor_key is not None
|
||||
monitor_val = self._get_monitor_value(monitor_key)
|
||||
if monitor_val is None:
|
||||
if config.strict:
|
||||
|
@ -485,11 +485,11 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
)
|
||||
self.scheduler_progress.increment_completed()
|
||||
|
||||
def _get_monitor_value(self, key: str) -> Any:
|
||||
def _get_monitor_value(self, key: str) -> Optional[Any]:
|
||||
# this is a separate method to aid in testing
|
||||
return self.trainer.callback_metrics.get(key)
|
||||
|
||||
def _should_check_val_epoch(self):
|
||||
def _should_check_val_epoch(self) -> bool:
|
||||
return self.trainer.enable_validation and (
|
||||
self.trainer.check_val_every_n_epoch is None
|
||||
or (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
|
||||
|
@ -531,7 +531,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
|
||||
if self._dataloader_state_dict:
|
||||
data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict)
|
||||
self._dataloader_state_dict = None
|
||||
self._dataloader_state_dict = {}
|
||||
|
||||
def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
|
||||
"""Helper method to build the arguments for the current step.
|
||||
|
@ -564,12 +564,12 @@ def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) ->
|
|||
|
||||
|
||||
@overload
|
||||
def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any:
|
||||
def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]:
|
||||
def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any:
|
||||
...
|
||||
|
||||
|
||||
|
@ -587,7 +587,7 @@ def _recursive_unpad(nested: Union[Any, List[Any]], value: Optional[Any] = None)
|
|||
return [_recursive_unpad(item, value) for item in nested if item != value]
|
||||
|
||||
|
||||
def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.array:
|
||||
def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.ndarray:
|
||||
"""Pads a jagged nested list of lists with the given value such that a proper multi-dimensional array can be
|
||||
formed with rectangular shape. The padding appends to the incomplete lists.
|
||||
|
||||
|
@ -618,7 +618,7 @@ def _get_max_shape(array: List[Any]) -> List[int]:
|
|||
>>> _get_max_shape([[], [[1], [2]], []])
|
||||
[3, 2, 1]
|
||||
"""
|
||||
dimensions = defaultdict(int)
|
||||
dimensions: DefaultDict[int, int] = defaultdict(int)
|
||||
for level, length in _get_dimensions(array):
|
||||
dimensions[level] = max(dimensions[level], length)
|
||||
return [value for _, value in sorted(dimensions.items())]
|
||||
|
|
Loading…
Reference in New Issue