Fix mypy errors attributed to `pytorch_lightning.core.module.py` (#13603)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Justin Goheen, 2022-08-19 07:03:43 -04:00, committed by GitHub
parent c9b3cda0e0
commit 090bbc8605
5 changed files with 60 additions and 53 deletions
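The change is largely mechanical: `pytorch_lightning.core.module` is dropped from what appears to be the per-module mypy ignore list in `pyproject.toml` (first hunk below), and the module's annotations are tightened until the checker passes. A minimal sketch, with hypothetical names rather than Lightning code, of the kind of `var-annotated` error this surfaces and the one-line fix used throughout the diff (compare the `_param_requires_grad_state` change further down):

# Hypothetical example, not Lightning code: a bare container literal gives
# mypy nothing to infer once the module is actually checked.
from typing import Dict


class Example:
    def __init__(self) -> None:
        # mypy: Need type annotation for "_flags"  [var-annotated]
        # self._flags = {}
        self._flags: Dict[str, bool] = {}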

View File

@@ -52,7 +52,6 @@ module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.module",
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.profilers.base",

View File

@@ -37,7 +37,7 @@ class DeviceDtypeModuleMixin(Module):
raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).")
@property
def device(self) -> Union[str, torch.device]:
def device(self) -> torch.device:
device = self._device
# make this more explicit to always include the index
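Narrowing the declared return type from `Union[str, torch.device]` to `torch.device` matches what the property actually yields once the value is normalised. A hedged sketch of the idea, using a hypothetical holder class rather than the real mixin:

import torch


class _DeviceHolder:
    """Hypothetical stand-in for a module that tracks its device."""

    def __init__(self, device: str = "cpu") -> None:
        # always stored in normalised form, so the property can promise
        # a plain torch.device to callers
        self._device = torch.device(device)

    @property
    def device(self) -> torch.device:
        return self._device

With the precise annotation, call sites no longer need isinstance checks or a `typing.cast` (see the `LightningParallelModule` hunk near the end of the diff).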

View File

@@ -22,7 +22,7 @@ import warnings
import weakref
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, overload, Sequence, Tuple, Union
import torch
from torch import ScriptModule, Tensor
@@ -47,12 +47,20 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
from pytorch_lightning.utilities.types import (
_METRIC_COLLECTION,
EPOCH_OUTPUT,
LRSchedulerPLType,
LRSchedulerTypeUnion,
STEP_OUTPUT,
)
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
log = logging.getLogger(__name__)
MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]
class LightningModule(
DeviceDtypeModuleMixin,
@@ -104,7 +112,7 @@ class LightningModule(
self._current_fx_name: Optional[str] = None
self._automatic_optimization: bool = True
self._truncated_bptt_steps: int = 0
self._param_requires_grad_state = {}
self._param_requires_grad_state: Dict[str, bool] = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
# TODO: remove in 1.8
@@ -121,14 +129,10 @@ class LightningModule(
...
@overload
def optimizers(
self, use_pl_optimizer: bool
) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
...
def optimizers(
self, use_pl_optimizer: bool = True
) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
"""Returns the optimizer(s) that are being used during training. Useful for manual optimization.
Args:
@@ -140,7 +144,7 @@ class LightningModule(
A single optimizer, or a list of optimizers in case multiple ones are present.
"""
if use_pl_optimizer:
opts = list(self.trainer.strategy._lightning_optimizers.values())
opts: MODULE_OPTIMIZERS = list(self.trainer.strategy._lightning_optimizers.values())
else:
opts = self.trainer.optimizers
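The new module-level `MODULE_OPTIMIZERS` alias keeps the `@overload` signatures, the implementation signature, and the local `opts` annotation in sync. A minimal sketch of the same alias-plus-overload pattern, with hypothetical names:

from typing import List, Union, overload

INT_OR_LIST = Union[int, List[int]]


class Holder:
    @overload
    def values(self) -> int:
        ...

    @overload
    def values(self, as_list: bool) -> INT_OR_LIST:
        ...

    def values(self, as_list: bool = False) -> INT_OR_LIST:
        # the union is written once and reused in every signature
        if as_list:
            return [1, 2, 3]
        return 1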
@@ -150,7 +154,7 @@ class LightningModule(
# multiple opts
return opts
def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
"""Returns the learning rate scheduler(s) that are being used during training. Useful for manual
optimization.
@@ -162,7 +166,7 @@ class LightningModule(
return None
# ignore other keys "interval", "frequency", etc.
lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]
lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]
# single scheduler
if len(lr_schedulers) == 1:
@@ -175,13 +179,13 @@ class LightningModule(
def trainer(self) -> "pl.Trainer":
if not self._running_torchscript and self._trainer is None:
raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
return self._trainer
return self._trainer # type: ignore[return-value]
@trainer.setter
def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
for v in self.children():
if isinstance(v, LightningModule):
v.trainer = trainer
v.trainer = trainer # type: ignore[assignment]
if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
trainer = weakref.proxy(trainer)
self._trainer = trainer
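Where an error cannot be fixed by a better annotation alone, the commit uses error-code-scoped suppressions (`# type: ignore[return-value]`, `[assignment]`, `[arg-type]`, ...) rather than bare `# type: ignore` comments, so every other check on the line stays active. A hedged, hypothetical sketch of the `return-value` case, where the attribute is Optional and the runtime guard is too conditional for mypy to narrow it:

from typing import Optional


class Owner:
    def __init__(self) -> None:
        self._label: Optional[str] = None
        self._strict = True

    @property
    def label(self) -> str:
        if self._strict and self._label is None:
            raise RuntimeError("label was never set")
        # mypy still sees Optional[str] here (the guard only fires when both
        # conditions hold), so only the return-value check is silenced.
        return self._label  # type: ignore[return-value]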
@@ -228,7 +232,7 @@ class LightningModule(
return self.trainer.local_rank if self._trainer else 0
@property
def on_gpu(self):
def on_gpu(self) -> bool:
"""Returns ``True`` if this model is currently located on a GPU.
Useful to set flags around the LightningModule for different CPU vs GPU behavior.
@@ -264,7 +268,7 @@ class LightningModule(
# this should match the implementation of `trainer.logger`
# we don't reuse it so we can properly set the deprecation stacklevel
if self._trainer is None:
return
return None
loggers = self.trainer.loggers
if len(loggers) == 0:
return None
@@ -287,15 +291,15 @@ class LightningModule(
"""Reference to the list of loggers in the Trainer."""
return self.trainer.loggers if self._trainer else []
def _call_batch_hook(self, hook_name, *args) -> Any:
def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
if self._trainer:
datahook_selector = self._trainer._data_connector._datahook_selector
obj = datahook_selector.get_instance(hook_name)
trainer_method = (
self._trainer._call_lightning_module_hook
if isinstance(obj, self.__class__)
else self._trainer._call_lightning_datamodule_hook
)
if isinstance(obj, self.__class__):
trainer_method = self._trainer._call_lightning_module_hook
else:
trainer_method = self._trainer._call_lightning_datamodule_hook
return trainer_method(hook_name, *args)
else:
hook = getattr(self, hook_name)
@@ -312,7 +316,7 @@ class LightningModule(
batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
return batch
def print(self, *args, **kwargs) -> None:
def print(self, *args: Any, **kwargs: Any) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
@@ -463,7 +467,7 @@ class LightningModule(
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
reduce_fx=reduce_fx, # type: ignore[arg-type]
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
batch_size=batch_size,
@@ -578,7 +582,9 @@ class LightningModule(
"""
self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)
def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
def all_gather(
self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
) -> Union[Tensor, Dict, List, Tuple]:
r"""
Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several
@@ -598,7 +604,7 @@ class LightningModule(
data = convert_to_tensors(data, device=self.device)
return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
def forward(self, *args, **kwargs) -> Any:
def forward(self, *args: Any, **kwargs: Any) -> Any:
r"""
Same as :meth:`torch.nn.Module.forward()`.
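Several hooks gain `*args: Any, **kwargs: Any`. The annotation applies per element, not to the packed containers, as this small sketch with a hypothetical function shows:

from typing import Any


def forward_like(*args: Any, **kwargs: Any) -> Any:
    # inside the body mypy sees args as Tuple[Any, ...] and
    # kwargs as Dict[str, Any]
    return args, kwargs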
@@ -611,7 +617,7 @@ class LightningModule(
"""
return super().forward(*args, **kwargs)
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
r"""
Here you compute and return the training loss and some additional metrics, e.g. for
the progress bar or logger.
@@ -769,7 +775,7 @@ class LightningModule(
...
"""
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
r"""
Operates on a single batch of data from the validation set.
In this step you might generate examples or calculate anything of interest like accuracy.
@@ -858,7 +864,7 @@ class LightningModule(
the model goes back to training mode and gradients are enabled.
"""
def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def validation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""Use this when validating with dp because :meth:`validation_step` will operate on only part of the batch.
However, this is still optional and only needed for things like softmax or NCE loss.
@@ -955,7 +961,7 @@ class LightningModule(
self.log("final_metric", final_value)
"""
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
r"""
Operates on a single batch of data from the test set.
In this step you'd normally generate examples or calculate anything of interest
@@ -1035,7 +1041,7 @@ class LightningModule(
to training mode and gradients are enabled.
"""
def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
def test_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
"""Use this when testing with DP because :meth:`test_step` will operate on only part of the batch. However,
this is still optional and only needed for things like softmax or NCE loss.
@@ -1200,7 +1206,7 @@ class LightningModule(
"""
return []
def configure_optimizers(self):
def configure_optimizers(self) -> Any:
r"""
Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
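`configure_optimizers` gets the broad `Any` return annotation because it supports several documented return shapes (a single optimizer, lists of optimizers, an optimizer/scheduler dictionary, ...). A sketch of one common shape, using a plain `nn.Module` stand-in instead of a real `LightningModule`:

from typing import Any

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def configure_optimizers(self) -> Any:
        # one documented shape among several, hence the permissive annotation
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}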
@@ -1374,7 +1380,7 @@ class LightningModule(
"""
rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")
def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
"""Call this directly from your :meth:`training_step` when doing optimizations manually. By using this,
Lightning can ensure that all the proper scaling gets applied when using mixed precision.
@@ -1399,7 +1405,7 @@ class LightningModule(
self.trainer.strategy.backward(loss, None, None, *args, **kwargs)
def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
) -> None:
"""Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
own implementation if you need to.
@@ -1442,7 +1448,7 @@ class LightningModule(
# Then iterate over the current optimizer's parameters and set its `requires_grad`
# properties accordingly
for group in optimizer.param_groups:
for group in optimizer.param_groups: # type: ignore[union-attr]
for param in group["params"]:
param.requires_grad = param_requires_grad_state[param]
self._param_requires_grad_state = param_requires_grad_state
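The `union-attr` suppression above covers an attribute lookup on a value whose inferred type is a union in which not every member declares the attribute. A self-contained, hypothetical illustration of that error code:

from typing import Union


class WithGroups:
    param_groups = [{"params": []}]  # mimics torch.optim.Optimizer


class WithoutGroups:
    pass


def first_group(opt: Union[WithGroups, WithoutGroups]) -> dict:
    # mypy: Item "WithoutGroups" of "Union[...]" has no attribute
    # "param_groups"  [union-attr]
    return opt.param_groups[0]  # type: ignore[union-attr]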
@@ -1469,7 +1475,7 @@ class LightningModule(
optimizer: Optimizer,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
):
) -> None:
"""Handles gradient clipping internally.
Note:
@@ -1523,7 +1529,7 @@ class LightningModule(
optimizer_idx: int,
gradient_clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[str] = None,
):
) -> None:
"""Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
Args:
@@ -1584,7 +1590,7 @@ class LightningModule(
"""
if metric is None:
scheduler.step()
scheduler.step() # type: ignore[call-arg]
else:
scheduler.step(metric)
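The dispatch above exists because `ReduceLROnPlateau.step()` consumes the monitored metric while other schedulers step without one; the `call-arg` ignore presumably papers over the union of the two signatures. A standalone sketch of the same dispatch with stock torch schedulers:

import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

for scheduler, metric in ((plateau, 0.5), (step_lr, None)):
    if metric is None:
        scheduler.step()
    else:
        # ReduceLROnPlateau is the scheduler that needs the monitored value
        scheduler.step(metric)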
@@ -1672,7 +1678,7 @@ class LightningModule(
"""
optimizer.step(closure=optimizer_closure)
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Override this method to change the default behaviour of ``optimizer.zero_grad()``.
Args:
@@ -1741,12 +1747,11 @@ class LightningModule(
for t in range(0, time_dims[0], split_size):
batch_split = []
for i, x in enumerate(batch):
split_x: Union[Tensor, List[Tensor]]
if isinstance(x, Tensor):
split_x = x[:, t : t + split_size]
elif isinstance(x, collections.abc.Sequence):
split_x = [None] * len(x)
for batch_idx in range(len(x)):
split_x[batch_idx] = x[batch_idx][t : t + split_size]
elif isinstance(x, collections.Sequence):
split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]
batch_split.append(split_x)
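The rewritten branch above replaces the pre-allocated `[None] * len(x)` list, which mypy cannot reconcile with a `List[Tensor]` target, with a single list comprehension. A standalone sketch of the splitting logic on toy data (not the actual hook):

from typing import List, Union

import torch
from torch import Tensor

batch = [torch.arange(12).view(2, 6), [list(range(6)), list(range(6))]]
split_size, t = 3, 0

batch_split = []
for x in batch:
    split_x: Union[Tensor, List]
    if isinstance(x, Tensor):
        split_x = x[:, t : t + split_size]  # slice along the time dimension
    else:
        split_x = [x[i][t : t + split_size] for i in range(len(x))]
    batch_split.append(split_x)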
@@ -1782,7 +1787,7 @@ class LightningModule(
self.train()
def _verify_is_manual_optimization(self, fn_name):
def _verify_is_manual_optimization(self, fn_name: str) -> None:
if self.automatic_optimization:
raise MisconfigurationException(
f"to use {fn_name}, please disable automatic optimization:"
@@ -1790,7 +1795,7 @@ class LightningModule(
)
@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
"""Saves the model in ONNX format.
Args:
@@ -1829,7 +1834,7 @@ class LightningModule(
if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
self.eval()
if isinstance(input_sample, Tuple):
if isinstance(input_sample, tuple):
kwargs["example_outputs"] = self(*input_sample)
else:
kwargs["example_outputs"] = self(input_sample)
@@ -1843,7 +1848,7 @@ class LightningModule(
file_path: Optional[Union[str, Path]] = None,
method: Optional[str] = "script",
example_inputs: Optional[Any] = None,
**kwargs,
**kwargs: Any,
) -> Union[ScriptModule, Dict[str, ScriptModule]]:
"""By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
please provide the argument ``method='trace'`` and make sure that either the `example_inputs` argument is
@@ -1953,7 +1958,7 @@ class LightningModule(
self._use_amp = use_amp
@contextmanager
def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
self._should_prevent_trainer_and_dataloaders_deepcopy = True
yield
self._should_prevent_trainer_and_dataloaders_deepcopy = False
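A function decorated with `@contextmanager` is a generator, so its return annotation becomes `Generator[None, None, None]` rather than `None`. A minimal, hypothetical sketch of the pattern:

from contextlib import contextmanager
from typing import Dict, Generator


@contextmanager
def prevent_deepcopy(state: Dict[str, bool]) -> Generator[None, None, None]:
    # the decorated function yields exactly once; the annotation describes
    # the generator, not the value produced by the with-statement
    state["prevent"] = True
    try:
        yield
    finally:
        state["prevent"] = False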
@@ -1988,4 +1993,6 @@ class LightningModule(
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
else:
# We need to make sure the self inside the method is a weakref proxy
self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
self.__class__._register_load_state_dict_pre_hook(
weakref.proxy(self), pre_load_state_dict_hook, True # type: ignore[arg-type]
)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, cast, Optional, Union
from typing import Any, Optional, Union
import torch
from torch import Tensor
@@ -77,7 +77,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
output = super().forward(*inputs, **kwargs)
def output_transform(data: Any) -> Any:
device = cast(torch.device, self.lightning_module.device)
device = self.lightning_module.device
data = python_scalar_to_tensor(data, device)
data = unsqueeze_scalar_tensor(data)
return data
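With `device` now declared as `torch.device` in the mixin hunk near the top, the `typing.cast` in this wrapper is redundant and is dropped. A hedged before/after sketch with a hypothetical holder:

from typing import cast

import torch


class Holder:
    device: torch.device = torch.device("cpu")


holder = Holder()
# before: the attribute was typed as a Union, so a cast was needed
device_before = cast(torch.device, holder.device)
# after: the attribute is already precisely typed
device_after = holder.device
assert device_before == device_after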

View File

@@ -168,6 +168,7 @@ class DistributedDataParallel(Protocol):
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
@dataclass
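The added `LRSchedulerPLType` alias backs the `lr_schedulers()` return annotation earlier in the diff. A hedged sketch of how such a scheduler union reads at a call site, using stock torch classes and hypothetical names rather than the actual alias:

from typing import List, Union

from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

SchedulerUnion = Union[StepLR, ReduceLROnPlateau]


def single_or_list(schedulers: List[SchedulerUnion]) -> Union[None, List[SchedulerUnion], SchedulerUnion]:
    if not schedulers:
        return None
    if len(schedulers) == 1:
        return schedulers[0]
    return schedulers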