Load checkpoint from Bytes (#4314)
* load directly from fs * if not str or path * pep8 * type annotation Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
parent
3abfec8962
commit
91c64e9c82
|
@ -17,7 +17,7 @@ import csv
|
|||
import inspect
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
|
||||
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
|
||||
from warnings import warn
|
||||
|
||||
import fsspec
|
||||
|
@ -52,7 +52,7 @@ class ModelIO(object):
|
|||
@classmethod
|
||||
def load_from_checkpoint(
|
||||
cls,
|
||||
checkpoint_path: str,
|
||||
checkpoint_path: Union[str, IO],
|
||||
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
|
||||
hparams_file: Optional[str] = None,
|
||||
strict: bool = True,
|
||||
|
@ -65,7 +65,7 @@ class ModelIO(object):
|
|||
Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.
|
||||
|
||||
Args:
|
||||
checkpoint_path: Path to checkpoint. This can also be a URL.
|
||||
checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
|
||||
map_location:
|
||||
If your checkpoint saved a GPU model and you now load on CPUs
|
||||
or a different number of GPUs, use this to map to the new setup.
|
||||
|
|
|
@ -14,14 +14,17 @@
|
|||
|
||||
import io
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Union
|
||||
from typing import Union, IO
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
import torch
|
||||
import fsspec
|
||||
|
||||
|
||||
def load(path_or_url: str, map_location=None):
|
||||
def load(path_or_url: Union[str, IO, Path], map_location=None):
|
||||
if not isinstance(path_or_url, (str, Path)):
|
||||
# any sort of BytesIO or similiar
|
||||
return torch.load(path_or_url, map_location=map_location)
|
||||
if path_or_url.startswith("http"):
|
||||
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
|
||||
fs = get_filesystem(path_or_url)
|
||||
|
|
Loading…
Reference in New Issue