diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index c07dd944ee..7124007cd3 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -17,7 +17,7 @@ import csv import inspect import os from argparse import Namespace -from typing import Union, Dict, Any, Optional, Callable, MutableMapping +from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO from warnings import warn import fsspec @@ -52,7 +52,7 @@ class ModelIO(object): @classmethod def load_from_checkpoint( cls, - checkpoint_path: str, + checkpoint_path: Union[str, IO], map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, strict: bool = True, @@ -65,7 +65,7 @@ class ModelIO(object): Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`. Args: - checkpoint_path: Path to checkpoint. This can also be a URL. + checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object map_location: If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 863cd3617d..33845384fa 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -14,14 +14,17 @@ import io from distutils.version import LooseVersion -from typing import Union +from typing import Union, IO from pathlib import Path from urllib.parse import urlparse import torch import fsspec -def load(path_or_url: str, map_location=None): +def load(path_or_url: Union[str, IO, Path], map_location=None): + if not isinstance(path_or_url, (str, Path)): + # any sort of BytesIO or similiar + return torch.load(path_or_url, map_location=map_location) if path_or_url.startswith("http"): return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) fs = get_filesystem(path_or_url)