Run mypy with PyTorch 1.12 (#14044)

Carlos Mocholí 2022-08-08 10:06:41 +02:00 committed by GitHub
parent 5c05719f27
commit 76836a33cd
6 changed files with 10 additions and 10 deletions


@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
           # todo: adjust requirements for both code-bases
           pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
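The job installs the torch 1.12 CPU wheel up front so mypy type-checks against the 1.12 stubs, then reconciles the remaining pins before installing the full development requirements. Below is a rough, hypothetical stand-in for what a version-adjustment step like `adjust-versions.py` might do; it is not the actual script:

```python
# Hypothetical stand-in for a requirements-adjustment step: rewrite any explicit
# torch pin in a requirements file to match the torch build already installed.
import re
import sys

import torch


def adjust(requirements_path: str) -> None:
    installed = torch.__version__.split("+")[0]  # e.g. "1.12.0"
    with open(requirements_path) as fp:
        lines = fp.readlines()
    with open(requirements_path, "w") as fp:
        for line in lines:
            fp.write(re.sub(r"^torch\s*[<>=~!]=[^#\n]*", f"torch=={installed}", line))


if __name__ == "__main__":
    adjust(sys.argv[1])
```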


@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",


@@ -23,7 +23,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]


 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
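The `else` branch keeps the module importable on older torch versions by binding the name to `None` at runtime. Since mypy now runs against torch 1.12, it resolves the `if` branch and treats `MixedPrecision` as a class, so the runtime fallback needs an explicit ignore. A minimal sketch of the pattern; the version flag is computed inline here only for illustration (the real `_TORCH_GREATER_EQUAL_1_12` lives in `pytorch_lightning.utilities.imports`):

```python
# Minimal sketch of the guarded-import pattern shown in the diff above.
import torch
from packaging.version import Version

# Illustrative only; pytorch_lightning derives this flag in its imports utilities.
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__.split("+")[0]) >= Version("1.12.0")

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
    # mypy analyzes the `if` branch against the installed 1.12 stubs, so
    # re-binding the class name to None needs the explicit ignore.
    MixedPrecision = None  # type: ignore[misc,assignment]
```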


@@ -51,7 +51,7 @@ if _TORCH_GREATER_EQUAL_1_12:
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
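For context, these optional symbols are the FSDP configuration objects that a native fully-sharded strategy typically passes when wrapping a module. A hedged usage sketch under torch >= 1.12; the model and dtype choices are placeholders, not the strategy's actual defaults, and running it requires an initialized process group:

```python
# Illustrative FSDP wrapping with the objects imported above (torch >= 1.12).
# Requires torch.distributed.init_process_group() (e.g. under torchrun) to run.
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel,
    MixedPrecision,
)

model = torch.nn.Linear(32, 32)
wrapped = FullyShardedDataParallel(
    model,
    cpu_offload=CPUOffload(offload_params=True),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16
    ),
)
```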


@@ -144,7 +144,7 @@ class _MultiProcessingLauncher(_Launcher):
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
+            trainer.lightning_module.load_state_dict(ckpt)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)

         trainer.state = worker_output.trainer_state
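After the spawned workers exit, the main process reloads the last weights from the temporary checkpoint and then removes it; under the 1.12 annotations the `load_state_dict` call type-checks without the `arg-type` ignore. A simplified sketch of that round-trip, using plain `torch.load` in place of the strategy's CheckpointIO and a hypothetical helper name:

```python
# Simplified sketch of restoring weights in the main process after a spawned run.
import torch


def restore_main_process_weights(model: torch.nn.Module, weights_path: str) -> None:
    # The worker wrote this checkpoint before exiting; load it on CPU and
    # copy the parameters back into the main-process model.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
```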


@@ -22,14 +22,12 @@ import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem

-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _DEVICE, _PATH


 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[
-        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
-    ] = None,
+    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
 ) -> Any:
     """Loads a checkpoint.
@@ -41,7 +39,10 @@ def load(
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
     if str(path_or_url).startswith("http"):
-        return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+        )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
         return torch.load(f, map_location=map_location)
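Callers can keep passing the familiar `map_location` forms; a few example invocations against the new annotation follow. The checkpoint paths are placeholders, and the helper is assumed to be importable from pytorch_lightning's cloud IO utilities:

```python
# Example map_location forms accepted by the new annotation (placeholder paths).
import torch
from pytorch_lightning.utilities.cloud_io import load  # assumed import location

ckpt = load("checkpoints/last.ckpt", map_location="cpu")                    # str
ckpt = load("checkpoints/last.ckpt", map_location=torch.device("cuda", 0))  # torch.device
ckpt = load("checkpoints/last.ckpt", map_location={"cuda:0": "cpu"})        # Dict[_DEVICE, _DEVICE]
```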