From 090bbc8605ed6aca80a509185ff341b5895275db Mon Sep 17 00:00:00 2001
From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com>
Date: Fri, 19 Aug 2022 07:03:43 -0400
Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.core.module.py` (#13603)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
Co-authored-by: Carlos Mocholí
---
 pyproject.toml                            |   1 -
 .../core/mixins/device_dtype_mixin.py     |   2 +-
 src/pytorch_lightning/core/module.py      | 105 ++++++++++--------
 .../overrides/data_parallel.py            |   4 +-
 src/pytorch_lightning/utilities/types.py  |   1 +
 5 files changed, 60 insertions(+), 53 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 9f7cc28d0b..45f65b4c44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.module",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
index 2916d8b07c..5086583d8e 100644
--- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
+++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -37,7 +37,7 @@ class DeviceDtypeModuleMixin(Module):
         raise RuntimeError("Cannot set the dtype explicitly. Please use module.to(new_dtype).")

     @property
-    def device(self) -> Union[str, torch.device]:
+    def device(self) -> torch.device:
         device = self._device

         # make this more explicit to always include the index
diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py
index e02552547d..0926cc52ec 100644
--- a/src/pytorch_lightning/core/module.py
+++ b/src/pytorch_lightning/core/module.py
@@ -22,7 +22,7 @@ import warnings
 import weakref
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, overload, Sequence, Tuple, Union

 import torch
 from torch import ScriptModule, Tensor
@@ -47,12 +47,20 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import _METRIC_COLLECTION, EPOCH_OUTPUT, LRSchedulerTypeUnion, STEP_OUTPUT
+from pytorch_lightning.utilities.types import (
+    _METRIC_COLLECTION,
+    EPOCH_OUTPUT,
+    LRSchedulerPLType,
+    LRSchedulerTypeUnion,
+    STEP_OUTPUT,
+)
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()
 log = logging.getLogger(__name__)

+MODULE_OPTIMIZERS = Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]
+

 class LightningModule(
     DeviceDtypeModuleMixin,
@@ -104,7 +112,7 @@ class LightningModule(
         self._current_fx_name: Optional[str] = None
         self._automatic_optimization: bool = True
         self._truncated_bptt_steps: int = 0
-        self._param_requires_grad_state = {}
+        self._param_requires_grad_state: Dict[str, bool] = {}
         self._metric_attributes: Optional[Dict[int, str]] = None
         self._should_prevent_trainer_and_dataloaders_deepcopy: bool = False
         # TODO: remove in 1.8
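The hunks above lean on two routine typing patterns: narrowing a property's return annotation to one concrete type and annotating attributes that start out empty. A minimal, self-contained sketch of the same ideas, with names invented for illustration and not taken from this patch:

    from typing import Dict, List, Union

    import torch
    from torch.optim import Optimizer

    # Illustrative alias in the spirit of MODULE_OPTIMIZERS: a named Union keeps
    # long annotations readable and reusable across several signatures.
    OPTIMIZER_LIKE = Union[Optimizer, List[Optimizer]]


    class Typed:
        def __init__(self) -> None:
            # Annotating a container that starts out empty tells mypy what it will hold,
            # instead of letting it infer an unusable empty-dict type.
            self._requires_grad_state: Dict[str, bool] = {}

        @property
        def device(self) -> torch.device:
            # Returning one concrete type (instead of Union[str, torch.device]) spares
            # every caller an isinstance check or cast.
            return torch.device("cpu")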
@@ -121,14 +129,10 @@ class LightningModule(
         ...

     @overload
-    def optimizers(
-        self, use_pl_optimizer: bool
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool) -> MODULE_OPTIMIZERS:
         ...

-    def optimizers(
-        self, use_pl_optimizer: bool = True
-    ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]:
+    def optimizers(self, use_pl_optimizer: bool = True) -> MODULE_OPTIMIZERS:
         """Returns the optimizer(s) that are being used during training. Useful for manual optimization.

         Args:
@@ -140,7 +144,7 @@ class LightningModule(
             A single optimizer, or a list of optimizers in case multiple ones are present.
         """
         if use_pl_optimizer:
-            opts = list(self.trainer.strategy._lightning_optimizers.values())
+            opts: MODULE_OPTIMIZERS = list(self.trainer.strategy._lightning_optimizers.values())
         else:
             opts = self.trainer.optimizers

@@ -150,7 +154,7 @@ class LightningModule(
         # multiple opts
         return opts

-    def lr_schedulers(self) -> Optional[Union[LRSchedulerTypeUnion, List[LRSchedulerTypeUnion]]]:
+    def lr_schedulers(self) -> Union[None, List[LRSchedulerPLType], LRSchedulerPLType]:
         """Returns the learning rate scheduler(s) that are being used during training. Useful for manual
         optimization.
@@ -162,7 +166,7 @@ class LightningModule(
             return None

         # ignore other keys "interval", "frequency", etc.
-        lr_schedulers = [config.scheduler for config in self.trainer.lr_scheduler_configs]
+        lr_schedulers: List[LRSchedulerPLType] = [config.scheduler for config in self.trainer.lr_scheduler_configs]

         # single scheduler
         if len(lr_schedulers) == 1:
@@ -175,13 +179,13 @@ class LightningModule(
     def trainer(self) -> "pl.Trainer":
         if not self._running_torchscript and self._trainer is None:
             raise RuntimeError(f"{self.__class__.__qualname__} is not attached to a `Trainer`.")
-        return self._trainer
+        return self._trainer  # type: ignore[return-value]

     @trainer.setter
     def trainer(self, trainer: Optional["pl.Trainer"]) -> None:
         for v in self.children():
             if isinstance(v, LightningModule):
-                v.trainer = trainer
+                v.trainer = trainer  # type: ignore[assignment]
         if trainer is not None and not isinstance(trainer, weakref.ProxyTypes):
             trainer = weakref.proxy(trainer)
         self._trainer = trainer
@@ -228,7 +232,7 @@ class LightningModule(
         return self.trainer.local_rank if self._trainer else 0

     @property
-    def on_gpu(self):
+    def on_gpu(self) -> bool:
         """Returns ``True`` if this model is currently located on a GPU.

         Useful to set flags around the LightningModule for different CPU vs GPU behavior.
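The `optimizers()` change above keeps a pair of `@overload` declarations so call sites get a narrower return type depending on the argument. A toy sketch of that pattern, with a made-up function rather than the real Lightning signatures:

    from typing import List, Literal, Union, overload


    @overload
    def head(xs: List[int], only_first: Literal[True] = True) -> int:
        ...


    @overload
    def head(xs: List[int], only_first: bool) -> Union[int, List[int]]:
        ...


    def head(xs: List[int], only_first: bool = True) -> Union[int, List[int]]:
        # The @overload stubs only affect what mypy reports at call sites; this single
        # runtime implementation backs both declared signatures.
        return xs[0] if only_first else xs

Calling `head(xs, True)` is typed as `int`, while `head(xs, flag)` with a plain `bool` falls back to the wider union, which is the same effect the `MODULE_OPTIMIZERS` overloads aim for.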
@@ -264,7 +268,7 @@ class LightningModule(
         # this should match the implementation of `trainer.logger`
         # we don't reuse it so we can properly set the deprecation stacklevel
         if self._trainer is None:
-            return
+            return None
         loggers = self.trainer.loggers
         if len(loggers) == 0:
             return None
@@ -287,15 +291,15 @@ class LightningModule(
         """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self._trainer else []

-    def _call_batch_hook(self, hook_name, *args) -> Any:
+    def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
             obj = datahook_selector.get_instance(hook_name)
-            trainer_method = (
-                self._trainer._call_lightning_module_hook
-                if isinstance(obj, self.__class__)
-                else self._trainer._call_lightning_datamodule_hook
-            )
+            if isinstance(obj, self.__class__):
+                trainer_method = self._trainer._call_lightning_module_hook
+            else:
+                trainer_method = self._trainer._call_lightning_datamodule_hook
+
             return trainer_method(hook_name, *args)
         else:
             hook = getattr(self, hook_name)
@@ -312,7 +316,7 @@ class LightningModule(
         batch = self._call_batch_hook("on_after_batch_transfer", batch, dataloader_idx)
         return batch

-    def print(self, *args, **kwargs) -> None:
+    def print(self, *args: Any, **kwargs: Any) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
@@ -463,7 +467,7 @@ class LightningModule(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
-            reduce_fx=reduce_fx,
+            reduce_fx=reduce_fx,  # type: ignore[arg-type]
             enable_graph=enable_graph,
             add_dataloader_idx=add_dataloader_idx,
             batch_size=batch_size,
@@ -578,7 +582,9 @@ class LightningModule(
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=False, logger=True)

-    def all_gather(self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self, data: Union[Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[Tensor, Dict, List, Tuple]:
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making the ``all_gather`` operation
         accelerator agnostic. ``all_gather`` is a function provided by accelerators to gather a tensor from several
@@ -598,7 +604,7 @@ class LightningModule(
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)

-    def forward(self, *args, **kwargs) -> Any:
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         r"""
         Same as :meth:`torch.nn.Module.forward()`.
@@ -611,7 +617,7 @@ class LightningModule(
         """
         return super().forward(*args, **kwargs)

-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""
         Here you compute and return the training loss and some additional metrics for e.g.
         the progress bar or logger.
@@ -769,7 +775,7 @@ class LightningModule(
                 ...
         """

-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.
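Replacing the conditional expression in `_call_batch_hook` with an explicit `if`/`else` matters to mypy because it joins the two arms of a ternary into one inferred type; for two unrelated callables that join can be too wide to call. A small standalone illustration of the same workaround, with made-up functions:

    from typing import Any, Callable


    def handle_int(x: int) -> int:
        return x + 1


    def handle_str(x: str) -> str:
        return x.upper()


    def dispatch(value: Any) -> Any:
        # With `handler = handle_int if ... else handle_str`, mypy may infer a type too
        # broad to call. Declaring the variable once and assigning per branch keeps it usable.
        handler: Callable[[Any], Any]
        if isinstance(value, int):
            handler = handle_int
        else:
            handler = handle_str
        return handler(value)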
""" - def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def validation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: """Use this when validating with dp because :meth:`validation_step` will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss. @@ -955,7 +961,7 @@ class LightningModule( self.log("final_metric", final_value) """ - def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: r""" Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest @@ -1035,7 +1041,7 @@ class LightningModule( to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def test_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: """Use this when testing with DP because :meth:`test_step` will operate on only part of the batch. However, this is still optional and only needed for things like softmax or NCE loss. @@ -1200,7 +1206,7 @@ class LightningModule( """ return [] - def configure_optimizers(self): + def configure_optimizers(self) -> Any: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. @@ -1374,7 +1380,7 @@ class LightningModule( """ rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") - def manual_backward(self, loss: Tensor, *args, **kwargs) -> None: + def manual_backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None: """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this, Lightning can ensure that all the proper scaling gets applied when using mixed precision. @@ -1399,7 +1405,7 @@ class LightningModule( self.trainer.strategy.backward(loss, None, None, *args, **kwargs) def backward( - self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs + self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args: Any, **kwargs: Any ) -> None: """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own implementation if you need to. @@ -1442,7 +1448,7 @@ class LightningModule( # Then iterate over the current optimizer's parameters and set its `requires_grad` # properties accordingly - for group in optimizer.param_groups: + for group in optimizer.param_groups: # type: ignore[union-attr] for param in group["params"]: param.requires_grad = param_requires_grad_state[param] self._param_requires_grad_state = param_requires_grad_state @@ -1469,7 +1475,7 @@ class LightningModule( optimizer: Optimizer, gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, - ): + ) -> None: """Handles gradient clipping internally. Note: @@ -1523,7 +1529,7 @@ class LightningModule( optimizer_idx: int, gradient_clip_val: Optional[Union[int, float]] = None, gradient_clip_algorithm: Optional[str] = None, - ): + ) -> None: """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`. 
@@ -1584,7 +1590,7 @@ class LightningModule(
         """
         if metric is None:
-            scheduler.step()
+            scheduler.step()  # type: ignore[call-arg]
         else:
             scheduler.step(metric)
@@ -1672,7 +1678,7 @@ class LightningModule(
         """
         optimizer.step(closure=optimizer_closure)

-    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
+    def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int) -> None:
         """Override this method to change the default behaviour of ``optimizer.zero_grad()``.

         Args:
@@ -1741,12 +1747,11 @@ class LightningModule(
         for t in range(0, time_dims[0], split_size):
             batch_split = []
             for i, x in enumerate(batch):
+                split_x: Union[Tensor, List[Tensor]]
                 if isinstance(x, Tensor):
                     split_x = x[:, t : t + split_size]
-                elif isinstance(x, collections.abc.Sequence):
-                    split_x = [None] * len(x)
-                    for batch_idx in range(len(x)):
-                        split_x[batch_idx] = x[batch_idx][t : t + split_size]
+                elif isinstance(x, collections.abc.Sequence):
+                    split_x = [x[batch_idx][t : t + split_size] for batch_idx in range(len(x))]

                 batch_split.append(split_x)
@@ -1782,7 +1787,7 @@ class LightningModule(

         self.train()

-    def _verify_is_manual_optimization(self, fn_name):
+    def _verify_is_manual_optimization(self, fn_name: str) -> None:
         if self.automatic_optimization:
             raise MisconfigurationException(
                 f"to use {fn_name}, please disable automatic optimization:"
@@ -1790,7 +1795,7 @@ class LightningModule(
             )

     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs):
+    def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
         """Saves the model in ONNX format.

         Args:
@@ -1829,7 +1834,7 @@ class LightningModule(
         if not _TORCH_GREATER_EQUAL_1_10 and "example_outputs" not in kwargs:
             self.eval()
-            if isinstance(input_sample, Tuple):
+            if isinstance(input_sample, tuple):
                 kwargs["example_outputs"] = self(*input_sample)
             else:
                 kwargs["example_outputs"] = self(input_sample)
@@ -1843,7 +1848,7 @@ class LightningModule(
         file_path: Optional[Union[str, Path]] = None,
         method: Optional[str] = "script",
         example_inputs: Optional[Any] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> Union[ScriptModule, Dict[str, ScriptModule]]:
         """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`.
         If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the
         `example_inputs` argument is
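The `tbptt_split_batch` hunk pre-declares `split_x: Union[Tensor, List[Tensor]]` so mypy accepts a different type of assignment in each branch, and the `to_onnx` hunk swaps `typing.Tuple` for the builtin `tuple` in an `isinstance` check, since typing aliases are meant for annotations rather than runtime checks. A short, self-contained sketch of the pre-declared-union pattern (the function and its semantics are illustrative only):

    from typing import List, Union

    import torch
    from torch import Tensor


    def window(x: Union[Tensor, List[Tensor]], start: int, size: int) -> Union[Tensor, List[Tensor]]:
        # Declaring the union once lets each branch assign a different member type
        # without mypy narrowing the variable to whichever branch it saw first.
        out: Union[Tensor, List[Tensor]]
        if isinstance(x, Tensor):
            out = x[:, start : start + size]
        else:
            out = [t[start : start + size] for t in x]
        return out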
@@ -1953,7 +1958,7 @@ class LightningModule(
         self._use_amp = use_amp

     @contextmanager
-    def _prevent_trainer_and_dataloaders_deepcopy(self) -> None:
+    def _prevent_trainer_and_dataloaders_deepcopy(self) -> Generator[None, None, None]:
         self._should_prevent_trainer_and_dataloaders_deepcopy = True
         yield
         self._should_prevent_trainer_and_dataloaders_deepcopy = False
@@ -1988,4 +1993,6 @@ class LightningModule(
             self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
         else:
             # We need to make sure the self inside the method is a weakref proxy
-            self.__class__._register_load_state_dict_pre_hook(weakref.proxy(self), pre_load_state_dict_hook, True)
+            self.__class__._register_load_state_dict_pre_hook(
+                weakref.proxy(self), pre_load_state_dict_hook, True  # type: ignore[arg-type]
+            )
diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py
index 98d23cee39..b296d1d869 100644
--- a/src/pytorch_lightning/overrides/data_parallel.py
+++ b/src/pytorch_lightning/overrides/data_parallel.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, cast, Optional, Union
+from typing import Any, Optional, Union

 import torch
 from torch import Tensor
@@ -77,7 +77,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         output = super().forward(*inputs, **kwargs)

         def output_transform(data: Any) -> Any:
-            device = cast(torch.device, self.lightning_module.device)
+            device = self.lightning_module.device
             data = python_scalar_to_tensor(data, device)
             data = unsqueeze_scalar_tensor(data)
             return data
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index 18e2db6feb..9f2db64226 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -168,6 +168,7 @@ class DistributedDataParallel(Protocol):
 LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
 LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
 LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerPLType = Union[_LRScheduler, ReduceLROnPlateau]


 @dataclass
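The `_prevent_trainer_and_dataloaders_deepcopy` hunk above shows the usual way to annotate a `@contextmanager` helper: the decorated generator yields nothing, so its return type is `Generator[None, None, None]` (an `Iterator[None]` annotation also satisfies mypy). A minimal runnable sketch of that pattern, with a hypothetical guard unrelated to Lightning's internals:

    from contextlib import contextmanager
    from typing import Generator


    @contextmanager
    def suppressed_deepcopy() -> Generator[None, None, None]:
        # The generator yields no value and accepts no sent value, hence the three Nones.
        print("enter")
        try:
            yield
        finally:
            print("exit")


    with suppressed_deepcopy():
        print("inside")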