diff --git a/pyproject.toml b/pyproject.toml
index 9d3e4fd80f..168e60e1e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ disable_error_code = "attr-defined"
 # style choices
 warn_no_return = "False"
 
-# Changes mypy default to ignore all errors
+# Ignore mypy errors for these files
 # TODO: the goal is for this to be empty
 [[tool.mypy.overrides]]
 # the list can be generated with:
@@ -63,8 +63,6 @@ module = [
     "pytorch_lightning.core.mixins.hparams_mixin",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.distributed.dist",
-    "pytorch_lightning.lite.lite",
-    "pytorch_lightning.lite.wrappers",
     "pytorch_lightning.loggers.base",
     "pytorch_lightning.loggers.comet",
     "pytorch_lightning.loggers.csv_logs",
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 9073f5dd54..0d292dba54 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -201,7 +201,7 @@ class LightningLite(ABC):
             for dataloader in dataloaders
         ]
         dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
-        return dataloaders
+        return dataloaders  # type: ignore[return-value]
 
     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
@@ -284,6 +284,18 @@ class LightningLite(ABC):
         with self._precision_plugin.forward_context():
             yield
 
+    @overload
+    def to_device(self, obj: nn.Module) -> nn.Module:
+        ...
+
+    @overload
+    def to_device(self, obj: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def to_device(self, obj: Any) -> Any:
+        ...
+
     def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
         """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not
         already on that device.
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 908ba06bdb..202404ef71 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -131,6 +131,7 @@ class _LiteDataLoader:
         iterator = iter(self._dataloader)
         if self._device is None:
             yield from iterator
+            return
         for item in iterator:
             yield move_data_to_device(item, self._device)
 
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 7010c0e878..be51cc9f92 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -241,7 +241,7 @@ class TrainingTypePlugin(ABC):
     def test_step_end(self, output):
         return output
 
-    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
+    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
         """Wraps the dataloader if necessary.
 
         Args:
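
Note (illustration, not part of the patch): the `@overload` declarations added to `LightningLite.to_device` let mypy infer a precise return type at each call site instead of the broad `Union`. A minimal sketch of the pattern, using a hypothetical `_Sketch` class that stands in for `LightningLite`:

    from typing import Any, overload

    import torch
    import torch.nn as nn
    from torch import Tensor


    class _Sketch:
        """Hypothetical stand-in for LightningLite; shows only the overload pattern."""

        @overload
        def to_device(self, obj: nn.Module) -> nn.Module:
            ...

        @overload
        def to_device(self, obj: Tensor) -> Tensor:
            ...

        @overload
        def to_device(self, obj: Any) -> Any:
            ...

        def to_device(self, obj: Any) -> Any:
            # Single runtime implementation behind the typing-only overloads.
            return obj.to("cpu") if isinstance(obj, (nn.Module, Tensor)) else obj


    # mypy narrows the result per argument: a Tensor in yields a Tensor out,
    # a Module yields a Module, so no cast is needed at the call site.
    t: Tensor = _Sketch().to_device(torch.zeros(1))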