Fix mypy typing for `utilities.cloud_io.py` (#8671)

Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
Daniel Stancl 2021-08-03 11:56:28 +02:00 committed by GitHub
parent 8274183bf2
commit 08ac079c2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 5 deletions

View File

@ -64,6 +64,7 @@ module = [
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
"pytorch_lightning.utilities.parsing",

View File

@ -14,15 +14,20 @@
import io
from pathlib import Path
from typing import IO, Union
from typing import Any, Callable, Dict, IO, Optional, Union
import fsspec
import torch
from fsspec.implementations.local import LocalFileSystem
from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
from packaging.version import Version
def load(path_or_url: Union[str, IO, Path], map_location=None):
def load(
path_or_url: Union[str, IO, Path],
map_location: Optional[
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
] = None,
) -> Any:
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similiar
return torch.load(path_or_url, map_location=map_location)
@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
return torch.load(f, map_location=map_location)
def get_filesystem(path: Union[str, Path]):
def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
path = str(path)
if "://" in path:
# use the fileystem from the protocol specified
@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]):
return LocalFileSystem()
def atomic_save(checkpoint, filepath: str):
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
Args: