Fix mypy typing for `utilities.cloud_io.py` (#8671)
Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
parent
8274183bf2
commit
08ac079c2f
|
@ -64,6 +64,7 @@ module = [
|
|||
"pytorch_lightning.trainer.connectors.logger_connector",
|
||||
"pytorch_lightning.utilities.argparse",
|
||||
"pytorch_lightning.utilities.cli",
|
||||
"pytorch_lightning.utilities.cloud_io",
|
||||
"pytorch_lightning.utilities.device_dtype_mixin",
|
||||
"pytorch_lightning.utilities.device_parser",
|
||||
"pytorch_lightning.utilities.parsing",
|
||||
|
|
|
@ -14,15 +14,20 @@
|
|||
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import IO, Union
|
||||
from typing import Any, Callable, Dict, IO, Optional, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
def load(path_or_url: Union[str, IO, Path], map_location=None):
|
||||
def load(
|
||||
path_or_url: Union[str, IO, Path],
|
||||
map_location: Optional[
|
||||
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
|
||||
] = None,
|
||||
) -> Any:
|
||||
if not isinstance(path_or_url, (str, Path)):
|
||||
# any sort of BytesIO or similiar
|
||||
return torch.load(path_or_url, map_location=map_location)
|
||||
|
@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
|
|||
return torch.load(f, map_location=map_location)
|
||||
|
||||
|
||||
def get_filesystem(path: Union[str, Path]):
|
||||
def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
|
||||
path = str(path)
|
||||
if "://" in path:
|
||||
# use the fileystem from the protocol specified
|
||||
|
@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]):
|
|||
return LocalFileSystem()
|
||||
|
||||
|
||||
def atomic_save(checkpoint, filepath: str):
|
||||
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
|
||||
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue