Improve typing for Lite (#10743)

* improve typing in pytorch_lightning/lite

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* include lite again

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-26 21:14:11 +01:00 committed by GitHub
parent e94aff1c5b
commit 81a0a44d8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 7 deletions

View File

@@ -36,7 +36,7 @@ disable_error_code = "attr-defined"
# style choices
warn_no_return = "False"
# Changes mypy default to ignore all errors
# Ignore mypy errors for these files
# TODO: the goal is for this to be empty
[[tool.mypy.overrides]]
# the list can be generated with:
@@ -63,8 +63,6 @@ module = [
"pytorch_lightning.core.mixins.hparams_mixin",
"pytorch_lightning.core.saving",
"pytorch_lightning.distributed.dist",
"pytorch_lightning.lite.lite",
"pytorch_lightning.lite.wrappers",
"pytorch_lightning.loggers.base",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.csv_logs",

View File

@@ -16,7 +16,7 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union
import torch
import torch.nn as nn
@@ -201,7 +201,7 @@ class LightningLite(ABC):
for dataloader in dataloaders
]
dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
return dataloaders
return dataloaders # type: ignore[return-value]
def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
@@ -284,6 +284,18 @@ class LightningLite(ABC):
with self._precision_plugin.forward_context():
yield
@overload
def to_device(self, obj: nn.Module) -> nn.Module:
...
@overload
def to_device(self, obj: Tensor) -> Tensor:
...
@overload
def to_device(self, obj: Any) -> Any:
...
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
on that device.

View File

@@ -131,6 +131,7 @@ class _LiteDataLoader:
iterator = iter(self._dataloader)
if self._device is None:
yield from iterator
return
for item in iterator:
yield move_data_to_device(item, self._device)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
import torch
from torch import Tensor
@@ -241,7 +241,7 @@ class TrainingTypePlugin(ABC):
def test_step_end(self, output):
return output
def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
"""Wraps the dataloader if necessary.
Args: