Load checkpoint from Bytes (#4314)

* load directly from fs

* if not str or path

* pep8

* type annotation

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
Teddy Koker 2020-10-23 11:29:13 -04:00 committed by GitHub
parent 3abfec8962
commit 91c64e9c82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 5 deletions

View File

@ -17,7 +17,7 @@ import csv
import inspect
import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
from warnings import warn
import fsspec
@ -52,7 +52,7 @@ class ModelIO(object):
@classmethod
def load_from_checkpoint(
cls,
checkpoint_path: str,
checkpoint_path: Union[str, IO],
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
hparams_file: Optional[str] = None,
strict: bool = True,
@ -65,7 +65,7 @@ class ModelIO(object):
Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.
Args:
checkpoint_path: Path to checkpoint. This can also be a URL.
checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
map_location:
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.

View File

@ -14,14 +14,17 @@
import io
from distutils.version import LooseVersion
from typing import Union
from typing import Union, IO
from pathlib import Path
from urllib.parse import urlparse
import torch
import fsspec
def load(path_or_url: str, map_location=None):
def load(path_or_url: Union[str, IO, Path], map_location=None):
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similiar
return torch.load(path_or_url, map_location=map_location)
if path_or_url.startswith("http"):
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
fs = get_filesystem(path_or_url)