Fix mypy errors attributed to `pytorch_lightning.core.module.py` (#13603)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Parent: c9b3cda0e0 · Commit: 090bbc8605

@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.module",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",

@@ -37,7 +37,7 @@ class DeviceDtypeModuleMixin(Module):
         raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).")

     @property
-    def device(self) -> Union[str, torch.device]:
+    def device(self) -> torch.device:
         device = self._device

         # make this more explicit to always include the index
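
Note: narrowing the return annotation to a concrete `torch.device` means callers can hand the property straight to `Tensor.to` without first ruling out `str`. A minimal standalone sketch of the pattern (toy `DeviceOwner` class, not the library's mixin; assumes only that `torch` is installed):

    import torch


    class DeviceOwner:
        """Toy stand-in for a module that tracks which device it lives on."""

        def __init__(self) -> None:
            self._device = torch.device("cpu")

        @property
        def device(self) -> torch.device:
            # A concrete ``torch.device`` return type means callers can feed it
            # straight into ``Tensor.to`` without narrowing a ``Union[str, torch.device]``.
            return self._device


    owner = DeviceOwner()
    x = torch.zeros(2, 3).to(owner.device)  # accepted by mypy and at runtime
    print(x.device)  # cpu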

@@ -22,7 +22,7 @@ import warnings
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, overload, Sequence, Tuple, Union

 import torch
 from torch import ScriptModule, Tensor

@@ -47,12 +47,20 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import (
+    _METRIC_COLLECTION,
+    EPOCH_OUTPUT,
+    LRSchedulerPLType,
+    LRSchedulerTypeUnion,
+    STEP_OUTPUT,
+)
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()
 log = logging.getLogger(__name__)

+MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]
+

 class LightningModule(
     DeviceDtypeModuleMixin,
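
Note: the new `MODULE_OPTIMIZERS` alias exists so the long optimizer `Union` is written once and reused by the `optimizers()` overloads further down. A torch-free sketch of the same idea with stand-in classes (names are illustrative only):

    from typing import List, Union


    class Optimizer:  # stand-in for torch.optim.Optimizer
        ...


    class WrappedOptimizer:  # stand-in for a wrapped/lightning optimizer
        ...


    # One alias, reused by every overload/signature that returns optimizers.
    ModuleOptimizers = Union[Optimizer, WrappedOptimizer, List[Optimizer], List[WrappedOptimizer]]


    def pick(use_wrapped: bool) -> ModuleOptimizers:
        return WrappedOptimizer() if use_wrapped else Optimizer()


    print(type(pick(True)).__name__)  # WrappedOptimizer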

@@ -104,7 +112,7 @@ class LightningModule(
         self._current_fx_name: Optional[str] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
-        self._param_requires_grad_state = {}
+        self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         # TODO: remove in 1.8

@@ -121,14 +129,10 @@ class LightningModule(
         ...

     @overload
-    def optimizers(
-        self, use_pl_optimizer: bool
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
         ...

-    def optimizers(
-        self, use_pl_optimizer: bool = True
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
         """Returns the optimizer(s) that are being used during training. Useful for manual optimization.

         Args:
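
Note: this is the standard `typing.overload` pattern: the decorated stubs only describe the accepted call signatures for the type checker, while the single undecorated definition is what actually runs. A self-contained sketch of the mechanism (not Lightning's code):

    from typing import List, Union, overload


    @overload
    def head(items: List[int]) -> int:
        ...


    @overload
    def head(items: List[str]) -> str:
        ...


    def head(items: Union[List[int], List[str]]) -> Union[int, str]:
        # Only this implementation exists at runtime; the stubs guide the checker,
        # so ``head([1, 2])`` is inferred as ``int`` and ``head(["a"])`` as ``str``.
        return items[0]


    first = head([1, 2, 3])  # mypy infers ``int`` here
    print(first + 1)  # 2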

@@ -140,7 +144,7 @@ class LightningModule(
             A single optimizer, or a list of optimizers in case multiple ones are present.
         """
         if use_pl_optimizer:
-            opts = list(self.trainer.strategy._lightning_optimizers.values())
+            opts: MODULE_OPTIMIZERS = list(self.trainer.strategy._lightning_optimizers.values())
         else:
             opts = self.trainer.optimizers


@@ -150,7 +154,7 @@ class LightningModule(
         # multiple opts
         return opts

-    def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
+    def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
         """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
         optimization.


@@ -162,7 +166,7 @@ class LightningModule(
             return None

         # ignore other keys "interval", "frequency", etc.
-        lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]
+        lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]

         # single scheduler
         if len(lr_schedulers) == 1:

@@ -175,13 +179,13 @@ class LightningModule(
     def trainer(self) -> "pl.Trainer":
         if not self._running_torchscript and self._trainer is None:
             raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
-        return self._trainer
+        return self._trainer  # type: ignore[return-value]

     @trainer.setter
     def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
         for v in self.children():
             if isinstance(v, LightningModule):
-                v.trainer = trainer
+                v.trainer = trainer  # type: ignore[assignment]
         if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
             trainer = weakref.proxy(trainer)
         self._trainer = trainer
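
Note: the ignores above paper over two things mypy cannot see here: `_trainer` is declared `Optional[...]` while the property promises a `Trainer`, and the stored object may be a `weakref.proxy` rather than a real `Trainer` instance. A minimal, self-contained sketch of the same shape (hypothetical `Trainer`/`Module` names, not the library's classes):

    import weakref
    from typing import Optional


    class Trainer:
        def fit(self) -> str:
            return "fitting"


    class Module:
        def __init__(self) -> None:
            self._trainer: Optional[Trainer] = None
            self._allow_detached = False  # stand-in for the torchscript escape hatch

        @property
        def trainer(self) -> Trainer:
            if not self._allow_detached and self._trainer is None:
                raise RuntimeError("not attached to a Trainer")
            # mypy cannot narrow Optional[Trainer] through the compound condition
            # above, so the real code silences the resulting [return-value] error.
            return self._trainer  # type: ignore[return-value]

        @trainer.setter
        def trainer(self, trainer: Optional[Trainer]) -> None:
            if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
                # Store a proxy to avoid a reference cycle; statically it is not a Trainer.
                trainer = weakref.proxy(trainer)
            self._trainer = trainer


    m, t = Module(), Trainer()
    m.trainer = t
    print(m.trainer.fit())  # "fitting" -- attribute access passes through the proxy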

@@ -228,7 +232,7 @@ class LightningModule(
         return self.trainer.local_rank if self._trainer else 0

     @property
-    def on_gpu(self):
+    def on_gpu(self) -> bool:
         """Returns ``True`` if this model is currently located on a GPU.

         Useful to set flags around the LightningModule for different CPU vs GPU behavior.

@@ -264,7 +268,7 @@ class LightningModule(
         # this should match the implementation of `trainer.logger`
         # we don't reuse it so we can properly set the deprecation stacklevel
         if self._trainer is None:
-            return
+            return None
         loggers = self.trainer.loggers
         if len(loggers) == 0:
             return None

@@ -287,15 +291,15 @@ class LightningModule(
         """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self._trainer else []

-    def _call_batch_hook(self, hook_name, *args) -> Any:
+    def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
             obj = datahook_selector.get_instance(hook_name)
-            trainer_method = (
-                self._trainer._call_lightning_module_hook
-                if isinstance(obj, self.__class__)
-                else self._trainer._call_lightning_datamodule_hook
-            )
+            if isinstance(obj, self.__class__):
+                trainer_method = self._trainer._call_lightning_module_hook
+            else:
+                trainer_method = self._trainer._call_lightning_datamodule_hook
+
             return trainer_method(hook_name, *args)
         else:
             hook = getattr(self, hook_name)
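
Note: splitting the conditional expression into explicit branches gives each assignment to `trainer_method` its own statement to type-check, instead of asking mypy to infer one type for a ternary over two different bound methods. A rough, torch-free sketch of making such dispatch explicit (hook and class names here are made up, not Lightning's API):

    from typing import Any, Callable


    class ModuleHooks:
        def on_batch(self, name: str, *args: Any) -> Any:
            return f"module handled {name}"


    class DataHooks:
        def on_batch(self, name: str, *args: Any) -> Any:
            return f"datamodule handled {name}"


    def dispatch(obj: object, module_hooks: ModuleHooks, data_hooks: DataHooks, name: str) -> Any:
        # Annotating the target up front gives both branches one type to satisfy.
        hook: Callable[..., Any]
        if isinstance(obj, ModuleHooks):
            hook = module_hooks.on_batch
        else:
            hook = data_hooks.on_batch
        return hook(name)


    print(dispatch(ModuleHooks(), ModuleHooks(), DataHooks(), "on_after_batch_transfer"))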

@@ -312,7 +316,7 @@ class LightningModule(
         batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch

-    def print(self, *args, **kwargs) -> None:
+    def print(self, *args: Any, **kwargs: Any) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.


@@ -463,7 +467,7 @@ class LightningModule(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
-            reduce_fx=reduce_fx,
+            reduce_fx=reduce_fx,  # type: ignore[arg-type]
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,

@@ -578,7 +582,9 @@ class LightningModule(
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)

-    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[Tensor, Dict, List, Tuple]:
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
         accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several

@@ -598,7 +604,7 @@ class LightningModule(
             data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)

-    def forward(self, *args, **kwargs) -> Any:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
         Same as :meth:`torch.nn.Module.forward()`.


@@ -611,7 +617,7 @@ class LightningModule(
         """
         return super().forward(*args, **kwargs)

-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""
         Here you compute and return the training loss and some additional metrics for e.g.
         the progress bar or logger.

@@ -769,7 +775,7 @@ class LightningModule(
             ...
         """

-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.

@@ -858,7 +864,7 @@ class LightningModule(
         the model goes back to training mode and gradients are enabled.
         """

-    def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when validating with dp because :meth:`validation_step` will operate on only part of the batch.
         However, this is still optional and only needed for things like softmax or NCE loss.


@@ -955,7 +961,7 @@ class LightningModule(
             self.log("final_metric", final_value)
         """

-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the test set.
         In this step you'd normally generate examples or calculate anything of interest

@@ -1035,7 +1041,7 @@ class LightningModule(
         to training mode and gradients are enabled.
         """

-    def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Use this when testing with DP because :meth:`test_step` will operate on only part of the batch. However,
         this is still optional and only needed for things like softmax or NCE loss.


@@ -1200,7 +1206,7 @@ class LightningModule(
         """
         return []

-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Any:
         r"""
         Choose what optimizers and learning-rate schedulers to use in your optimization.
         Normally you'd need one. But in the case of GANs or similar you might have multiple.

@@ -1374,7 +1380,7 @@ class LightningModule(
         """
         rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer")

-    def manual_backward(self, loss: Tensor, *args, **kwargs) -> None:
+    def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
         """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this,
         Lightning can ensure that all the proper scaling gets applied when using mixed precision.


@@ -1399,7 +1405,7 @@ class LightningModule(
         self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

     def backward(
-        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
+        self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any
     ) -> None:
         """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your
         own implementation if you need to.

@@ -1442,7 +1448,7 @@ class LightningModule(

         # Then iterate over the current optimizer's parameters and set its `requires_grad`
         # properties accordingly
-        for group in optimizer.param_groups:
+        for group in optimizer.param_groups:  # type: ignore[union-attr]
             for param in group["params"]:
                 param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state

@@ -1469,7 +1475,7 @@ class LightningModule(
         optimizer: Optimizer,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
         """Handles gradient clipping internally.

         Note:

@@ -1523,7 +1529,7 @@ class LightningModule(
         optimizer_idx: int,
         gradient_clip_val: Optional[Union[int, float]] = None,
         gradient_clip_algorithm: Optional[str] = None,
-    ):
+    ) -> None:
         """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.

         Args:

@@ -1584,7 +1590,7 @@ class LightningModule(

         """
         if metric is None:
-            scheduler.step()
+            scheduler.step()  # type: ignore[call-arg]
         else:
             scheduler.step(metric)

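
Note: the `# type: ignore[call-arg]` is needed because the scheduler argument is typed as a union, and `ReduceLROnPlateau.step()` declares a required metrics argument while the plain scheduler's `step()` takes none; a bare `.step()` on the union therefore fails to type-check even though the `metric is None` branch only ever sees a plain scheduler at runtime. A torch-free sketch of the same shape, using toy scheduler classes and `isinstance` narrowing instead of an ignore:

    from typing import Optional, Union


    class StepScheduler:
        def step(self) -> None:
            print("stepped without a metric")


    class PlateauScheduler:
        def step(self, metric: float) -> None:
            print(f"stepped with metric={metric}")


    def update(scheduler: Union[StepScheduler, PlateauScheduler], metric: Optional[float]) -> None:
        if metric is None:
            # On the union, mypy sees ``PlateauScheduler.step`` and demands ``metric``;
            # an isinstance check (or, as in the diff, a targeted ignore) resolves it.
            assert isinstance(scheduler, StepScheduler)
            scheduler.step()
        else:
            assert isinstance(scheduler, PlateauScheduler)
            scheduler.step(metric)


    update(StepScheduler(), None)
    update(PlateauScheduler(), 0.25)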

@@ -1672,7 +1678,7 @@ class LightningModule(
         """
         optimizer.step(closure=optimizer_closure)

-    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override this method to change the default behaviour of ``optimizer.zero_grad()``.

         Args:

@@ -1741,12 +1747,11 @@ class LightningModule(
         for t in range(0, time_dims[0], split_size):
             batch_split = []
             for i, x in enumerate(batch):
+                split_x: Union[Tensor, List[Tensor]]
                 if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
-                elif isinstance(x, collections.Sequence):
-                    split_x = [None] * len(x)
-                    for batch_idx in range(len(x)):
-                        split_x[batch_idx] = x[batch_idx][t : t + split_size]
+                elif isinstance(x, collections.abc.Sequence):
+                    split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]

                 batch_split.append(split_x)

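
Note: pre-declaring `split_x: Union[Tensor, List[Tensor]]` gives both branches a single declared type, and the comprehension replaces the `[None] * len(x)` buffer, which mypy infers as `List[None]` and then rejects when slices are written into it. A small torch-free sketch of the same trick on plain sequences (illustrative only):

    from typing import List, Sequence, Union

    Chunk = Union[str, List[str]]


    def take_window(x: Union[str, Sequence[str]], start: int, size: int) -> Chunk:
        # Declare the target type once so both branches type-check against it.
        window: Chunk
        if isinstance(x, str):
            window = x[start : start + size]
        else:
            # A comprehension is inferred as List[str]; pre-filling with None would
            # be inferred as List[None] and then rejected on element assignment.
            window = [item[start : start + size] for item in x]
        return window


    print(take_window("abcdefgh", 2, 3))        # "cde"
    print(take_window(["abcd", "efgh"], 1, 2))  # ["bc", "fg"]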

@@ -1782,7 +1787,7 @@ class LightningModule(

         self.train()

-    def _verify_is_manual_optimization(self, fn_name):
+    def _verify_is_manual_optimization(self, fn_name: str) -> None:
         if self.automatic_optimization:
             raise MisconfigurationException(
                 f"to use {fn_name}, please disable automatic optimization:"

@@ -1790,7 +1795,7 @@ class LightningModule(
             )

     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
+    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
         """Saves the model in ONNX format.

         Args:

@@ -1829,7 +1834,7 @@ class LightningModule(

         if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
             self.eval()
-            if isinstance(input_sample, Tuple):
+            if isinstance(input_sample, tuple):
                 kwargs["example_outputs"] = self(*input_sample)
             else:
                 kwargs["example_outputs"] = self(input_sample)
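
Note: runtime checks should use the builtin `tuple`; `typing.Tuple` is an annotation-only alias and is deprecated as an `isinstance` target on newer Python versions. A minimal illustration:

    from typing import Any


    def example_outputs(model_input: Any) -> str:
        # Use the builtin ``tuple`` for runtime checks; ``typing.Tuple`` is meant
        # for annotations, not for ``isinstance``.
        if isinstance(model_input, tuple):
            return f"unpacking {len(model_input)} positional inputs"
        return "single input"


    print(example_outputs((1, 2, 3)))  # unpacking 3 positional inputs
    print(example_outputs(42))         # single input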

@@ -1843,7 +1848,7 @@ class LightningModule(
         file_path: Optional[Union[str, Path]] = None,
         method: Optional[str] = "script",
         example_inputs: Optional[Any] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing,
         please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is

@@ -1953,7 +1958,7 @@ class LightningModule(
         self._use_amp = use_amp

     @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
+    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True
         yield
         self._should_prevent_trainer_and_dataloaders_deepcopy = False
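
Note: a function decorated with `@contextmanager` is a generator, so the accurate annotation is `Generator[None, None, None]` (or `Iterator[None]`) rather than `None`; otherwise mypy reports that a function declared to return `None` contains a `yield`. A minimal sketch of the corrected pattern:

    from contextlib import contextmanager
    from typing import Generator


    @contextmanager
    def prevent_deepcopy() -> Generator[None, None, None]:
        # yield type, send type, return type -- all None for a simple guard flag.
        flag = {"active": True}
        try:
            yield
        finally:
            flag["active"] = False


    with prevent_deepcopy():
        print("inside the guarded block")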

@@ -1988,4 +1993,6 @@ class LightningModule(
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
-            self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
+            self.__class__._register_load_state_dict_pre_hook(
+                weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
+            )

@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch import Tensor

@@ -77,7 +77,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         output = super().forward(*inputs, **kwargs)

         def output_transform(data: Any) -> Any:
-            device = cast(torch.device, self.lightning_module.device)
+            device = self.lightning_module.device
             data = python_scalar_to_tensor(data, device)
             data = unsqueeze_scalar_tensor(data)
             return data
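
Note: with the `device` property now annotated as `torch.device` (see the `DeviceDtypeModuleMixin` hunk above), the `typing.cast` at this call site adds nothing, which is why both the call and the `cast` import go away. A torch-free sketch of the before/after:

    from typing import cast


    class Location:
        ...


    class Owner:
        @property
        def location(self) -> Location:  # precise annotation at the source
            return Location()


    owner = Owner()

    # Before: the property was loosely typed, so call sites narrowed it manually.
    loc_old = cast(Location, owner.location)

    # After: the annotation carries the type, so a plain attribute access suffices.
    loc_new = owner.location

    print(type(loc_old).__name__, type(loc_new).__name__)  # Location Location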
|
@ -168,6 +168,7 @@ class DistributedDataParallel(Protocol):
|
|||
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
|
||||
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
|
||||
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
|
||||
LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
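
Note: `LRSchedulerPLType` joins the existing aliases because runtime `isinstance` checks want a tuple of classes while annotations want a `Union` of the same classes; keeping both forms side by side is a common pattern. A small sketch with stand-in scheduler classes (not the real torch schedulers):

    from typing import Union


    class _LRScheduler:  # stand-in for torch.optim.lr_scheduler._LRScheduler
        ...


    class ReduceLROnPlateau:  # stand-in for the plateau scheduler
        ...


    # Runtime checks want a tuple of classes ...
    SchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
    # ... while annotations want a Union of the same classes.
    SchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]


    def describe(scheduler: SchedulerPLType) -> str:
        assert isinstance(scheduler, SchedulerTypeTuple)
        return type(scheduler).__name__


    print(describe(ReduceLROnPlateau()))  # ReduceLROnPlateau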