Run mypy with PyTorch 1.12 (#14044)
This commit is contained in:
parent
5c05719f27
commit
76836a33cd
|
@ -32,7 +32,7 @@ jobs:
|
|||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
|
||||
# todo: adjust requirements for both code-bases
|
||||
pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||
|
|
|
@ -52,7 +52,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.callbacks.quantization",
|
||||
"pytorch_lightning.core.datamodule",
|
||||
"pytorch_lightning.core.decorators",
|
||||
"pytorch_lightning.core.module",
|
||||
"pytorch_lightning.core.saving",
|
||||
"pytorch_lightning.demos.boring_classes",
|
||||
|
|
|
@ -23,7 +23,7 @@ from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
|
|||
if _TORCH_GREATER_EQUAL_1_12:
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
|
||||
else:
|
||||
MixedPrecision = None
|
||||
MixedPrecision = None # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
|
||||
|
|
|
@ -51,7 +51,7 @@ if _TORCH_GREATER_EQUAL_1_12:
|
|||
)
|
||||
from torch.distributed.fsdp.wrap import enable_wrap
|
||||
else:
|
||||
MixedPrecision = None
|
||||
MixedPrecision = None # type: ignore[misc,assignment]
|
||||
BackwardPrefetch = None # type: ignore[misc,assignment]
|
||||
CPUOffload = None # type: ignore[misc,assignment]
|
||||
|
||||
|
|
|
@ -144,7 +144,7 @@ class _MultiProcessingLauncher(_Launcher):
|
|||
# load last weights
|
||||
if worker_output.weights_path is not None:
|
||||
ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
|
||||
trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type]
|
||||
trainer.lightning_module.load_state_dict(ckpt)
|
||||
self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
|
||||
|
||||
trainer.state = worker_output.trainer_state
|
||||
|
|
|
@ -22,14 +22,12 @@ import torch
|
|||
from fsspec.core import url_to_fs
|
||||
from fsspec.implementations.local import AbstractFileSystem
|
||||
|
||||
from pytorch_lightning.utilities.types import _PATH
|
||||
from pytorch_lightning.utilities.types import _DEVICE, _PATH
|
||||
|
||||
|
||||
def load(
|
||||
path_or_url: Union[IO, _PATH],
|
||||
map_location: Optional[
|
||||
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
|
||||
] = None,
|
||||
map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
|
||||
) -> Any:
|
||||
"""Loads a checkpoint.
|
||||
|
||||
|
@ -41,7 +39,10 @@ def load(
|
|||
# any sort of BytesIO or similar
|
||||
return torch.load(path_or_url, map_location=map_location)
|
||||
if str(path_or_url).startswith("http"):
|
||||
return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
|
||||
return torch.hub.load_state_dict_from_url(
|
||||
str(path_or_url),
|
||||
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
|
||||
)
|
||||
fs = get_filesystem(path_or_url)
|
||||
with fs.open(path_or_url, "rb") as f:
|
||||
return torch.load(f, map_location=map_location)
|
||||
|
|
Loading…
Reference in New Issue