diff --git a/pyproject.toml b/pyproject.toml index 996607d0d9..bb3f093e1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ module = [ "pytorch_lightning.trainer.connectors.logger_connector", "pytorch_lightning.utilities.argparse", "pytorch_lightning.utilities.cli", + "pytorch_lightning.utilities.cloud_io", "pytorch_lightning.utilities.device_dtype_mixin", "pytorch_lightning.utilities.device_parser", "pytorch_lightning.utilities.parsing", diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 6bd6a172a7..9b40f6d69c 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -14,15 +14,20 @@ import io from pathlib import Path -from typing import IO, Union +from typing import Any, Callable, Dict, IO, Optional, Union import fsspec import torch -from fsspec.implementations.local import LocalFileSystem +from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem from packaging.version import Version -def load(path_or_url: Union[str, IO, Path], map_location=None): +def load( + path_or_url: Union[str, IO, Path], + map_location: Optional[ + Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] + ] = None, +) -> Any: if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similiar return torch.load(path_or_url, map_location=map_location) @@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None): return torch.load(f, map_location=map_location) -def get_filesystem(path: Union[str, Path]): +def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem: path = str(path) if "://" in path: # use the fileystem from the protocol specified @@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]): return LocalFileSystem() -def atomic_save(checkpoint, filepath: str): +def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. Args: