Improve typing for Lite (#10743)
* improve typing in pytorch_lightning/lite * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * include lite again Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e94aff1c5b
commit
81a0a44d8f
|
@ -36,7 +36,7 @@ disable_error_code = "attr-defined"
|
|||
# style choices
|
||||
warn_no_return = "False"
|
||||
|
||||
# Changes mypy default to ignore all errors
|
||||
# Ignore mypy errors for these files
|
||||
# TODO: the goal is for this to be empty
|
||||
[[tool.mypy.overrides]]
|
||||
# the list can be generated with:
|
||||
|
@ -63,8 +63,6 @@ module = [
|
|||
"pytorch_lightning.core.mixins.hparams_mixin",
|
||||
"pytorch_lightning.core.saving",
|
||||
"pytorch_lightning.distributed.dist",
|
||||
"pytorch_lightning.lite.lite",
|
||||
"pytorch_lightning.lite.wrappers",
|
||||
"pytorch_lightning.loggers.base",
|
||||
"pytorch_lightning.loggers.comet",
|
||||
"pytorch_lightning.loggers.csv_logs",
|
||||
|
|
|
@ -16,7 +16,7 @@ from abc import ABC, abstractmethod
|
|||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -201,7 +201,7 @@ class LightningLite(ABC):
|
|||
for dataloader in dataloaders
|
||||
]
|
||||
dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
|
||||
return dataloaders
|
||||
return dataloaders # type: ignore[return-value]
|
||||
|
||||
def _setup_dataloader(
|
||||
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
|
||||
|
@ -284,6 +284,18 @@ class LightningLite(ABC):
|
|||
with self._precision_plugin.forward_context():
|
||||
yield
|
||||
|
||||
@overload
|
||||
def to_device(self, obj: nn.Module) -> nn.Module:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to_device(self, obj: Tensor) -> Tensor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to_device(self, obj: Any) -> Any:
|
||||
...
|
||||
|
||||
def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]:
|
||||
"""Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already
|
||||
on that device.
|
||||
|
|
|
@ -131,6 +131,7 @@ class _LiteDataLoader:
|
|||
iterator = iter(self._dataloader)
|
||||
if self._device is None:
|
||||
yield from iterator
|
||||
return
|
||||
|
||||
for item in iterator:
|
||||
yield move_data_to_device(item, self._device)
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -241,7 +241,7 @@ class TrainingTypePlugin(ABC):
|
|||
def test_step_end(self, output):
|
||||
return output
|
||||
|
||||
def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
|
||||
def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
|
||||
"""Wraps the dataloader if necessary.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue