diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f0ce0009d..460f5ce69a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Swap `torch.load` for `fsspec` load in DDP spawn backend ([#3787](https://github.com/PyTorchLightning/pytorch-lightning/pull/3787)) +- Swap `torch.load` for `fsspec` load in cloud_io loading ([#3692](https://github.com/PyTorchLightning/pytorch-lightning/pull/3692)) + ### Deprecated diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 2c6771c5d6..e053e6caca 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -25,9 +25,11 @@ pathlike = Union[Path, str] def load(path_or_url: str, map_location=None): - if urlparse(path_or_url).scheme == "" or Path(path_or_url).drive: # no scheme or with a drive letter - return torch.load(path_or_url, map_location=map_location) - return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) + if path_or_url.startswith("http"): + return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location) + fs = get_filesystem(path_or_url) + with fs.open(path_or_url, "rb") as f: + return torch.load(f, map_location=map_location) def get_filesystem(path: pathlike):