From 76836a33cdfa63e2c85c6f4ea9b2a1f174c973e2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 8 Aug 2022 10:06:41 +0200
Subject: [PATCH] Run mypy with PyTorch 1.12 (#14044)

---
 .github/workflows/code-checks.yml                 |  2 +-
 pyproject.toml                                    |  1 -
 .../plugins/precision/fully_sharded_native_amp.py |  2 +-
 .../strategies/fully_sharded_native.py            |  2 +-
 .../strategies/launchers/multiprocessing.py       |  2 +-
 src/pytorch_lightning/utilities/cloud_io.py       | 11 ++++++-----
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml
index 7b5f3f2660..15bd5e9911 100644
--- a/.github/workflows/code-checks.yml
+++ b/.github/workflows/code-checks.yml
@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
           # todo: adjust requirements for both code-bases
           pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
diff --git a/pyproject.toml b/pyproject.toml
index 5473e73c52..9b8400ba27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
index 8c693f2975..60e53b880c 100644
--- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -23,7 +23,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 4c351f26fa..d92931fb5c 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -51,7 +51,7 @@ if _TORCH_GREATER_EQUAL_1_12:
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
 
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 39bba092e9..2617e5fe27 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -144,7 +144,7 @@ class _MultiProcessingLauncher(_Launcher):
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
+            trainer.lightning_module.load_state_dict(ckpt)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py
index 81482a8ab2..ee3358be59 100644
--- a/src/pytorch_lightning/utilities/cloud_io.py
+++ b/src/pytorch_lightning/utilities/cloud_io.py
@@ -22,14 +22,12 @@ import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem
 
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _DEVICE, _PATH
 
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[
-        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
-    ] = None,
+    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
 ) -> Any:
     """Loads a checkpoint.
 
@@ -41,7 +39,10 @@ def load(
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
     if str(path_or_url).startswith("http"):
-        return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type]  # upstream annotation is not correct
+        )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
         return torch.load(f, map_location=map_location)
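
A note on the pattern that recurs in this patch: when a symbol such as
MixedPrecision only exists on newer PyTorch, the else-branch fallback needs
"# type: ignore[misc,assignment]" because mypy sees a class on one branch and
None on the other. A minimal, self-contained sketch of that guard (the
describe_precision helper is hypothetical; the guard constant mirrors
pytorch_lightning's _TORCH_GREATER_EQUAL_1_12 but is computed inline here):

from typing import Optional

import torch
from packaging.version import Version

# Mirrors pytorch_lightning.utilities.imports._TORCH_GREATER_EQUAL_1_12.
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
    # The name must still exist on older torch so annotations below resolve;
    # mypy flags the class-vs-None mismatch, hence the ignore.
    MixedPrecision = None  # type: ignore[misc,assignment]


def describe_precision(config: Optional["MixedPrecision"]) -> str:
    # Hypothetical helper: report the configured FSDP parameter dtype.
    if config is None:
        return "full precision (fp32)"
    return f"param dtype: {config.param_dtype}"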
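
And a usage sketch for the retyped load() helper in cloud_io.py (the checkpoint
paths are hypothetical; _DEVICE is the device alias from
pytorch_lightning.utilities.types covering torch.device, str and int, so plain
device strings and storage-remapping dicts both satisfy the new annotation):

from pytorch_lightning.utilities.cloud_io import load

# Load onto CPU regardless of where the checkpoint was saved.
ckpt = load("lightning_logs/version_0/checkpoints/last.ckpt", map_location="cpu")

# Remap storages saved on cuda:0 to cpu explicitly via a dict.
ckpt = load("s3://my-bucket/checkpoints/last.ckpt", map_location={"cuda:0": "cpu"})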