Fix mypy typing for `utilities.cloud_io.py` (#8671)
Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
parent
8274183bf2
commit
08ac079c2f
|
@ -64,6 +64,7 @@ module = [
|
||||||
"pytorch_lightning.trainer.connectors.logger_connector",
|
"pytorch_lightning.trainer.connectors.logger_connector",
|
||||||
"pytorch_lightning.utilities.argparse",
|
"pytorch_lightning.utilities.argparse",
|
||||||
"pytorch_lightning.utilities.cli",
|
"pytorch_lightning.utilities.cli",
|
||||||
|
"pytorch_lightning.utilities.cloud_io",
|
||||||
"pytorch_lightning.utilities.device_dtype_mixin",
|
"pytorch_lightning.utilities.device_dtype_mixin",
|
||||||
"pytorch_lightning.utilities.device_parser",
|
"pytorch_lightning.utilities.device_parser",
|
||||||
"pytorch_lightning.utilities.parsing",
|
"pytorch_lightning.utilities.parsing",
|
||||||
|
|
|
@ -14,15 +14,20 @@
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import IO, Union
|
from typing import Any, Callable, Dict, IO, Optional, Union
|
||||||
|
|
||||||
import fsspec
|
import fsspec
|
||||||
import torch
|
import torch
|
||||||
from fsspec.implementations.local import LocalFileSystem
|
from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_url: Union[str, IO, Path], map_location=None):
|
def load(
|
||||||
|
path_or_url: Union[str, IO, Path],
|
||||||
|
map_location: Optional[
|
||||||
|
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
|
||||||
|
] = None,
|
||||||
|
) -> Any:
|
||||||
if not isinstance(path_or_url, (str, Path)):
|
if not isinstance(path_or_url, (str, Path)):
|
||||||
# any sort of BytesIO or similiar
|
# any sort of BytesIO or similiar
|
||||||
return torch.load(path_or_url, map_location=map_location)
|
return torch.load(path_or_url, map_location=map_location)
|
||||||
|
@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
|
||||||
return torch.load(f, map_location=map_location)
|
return torch.load(f, map_location=map_location)
|
||||||
|
|
||||||
|
|
||||||
def get_filesystem(path: Union[str, Path]):
|
def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
|
||||||
path = str(path)
|
path = str(path)
|
||||||
if "://" in path:
|
if "://" in path:
|
||||||
# use the fileystem from the protocol specified
|
# use the fileystem from the protocol specified
|
||||||
|
@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]):
|
||||||
return LocalFileSystem()
|
return LocalFileSystem()
|
||||||
|
|
||||||
|
|
||||||
def atomic_save(checkpoint, filepath: str):
|
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
|
||||||
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
|
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
Loading…
Reference in New Issue