From 06e2635c719c7552b4d175726d21afe7c2df930b Mon Sep 17 00:00:00 2001 From: DuYicong515 Date: Wed, 2 Feb 2022 17:55:24 -0800 Subject: [PATCH] Refactor get_filesystem to use native fsspec API (#11708) --- pytorch_lightning/utilities/cloud_io.py | 15 +++++------ tests/utilities/test_cloud_io.py | 34 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 8 deletions(-) create mode 100644 tests/utilities/test_cloud_io.py diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py index 60e1c92d64..446726f2fe 100644 --- a/pytorch_lightning/utilities/cloud_io.py +++ b/pytorch_lightning/utilities/cloud_io.py @@ -19,7 +19,10 @@ from typing import Any, Callable, Dict, IO, Optional, Union import fsspec import torch -from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem +from fsspec.core import url_to_fs +from fsspec.implementations.local import AbstractFileSystem + +from pytorch_lightning.utilities.types import _PATH def load( @@ -44,13 +47,9 @@ def load( return torch.load(f, map_location=map_location) -def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem: - path = str(path) - if "://" in path: - # use the fileystem from the protocol specified - return fsspec.filesystem(path.split(":", 1)[0]) - # use local filesystem - return LocalFileSystem() +def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: + fs, _ = url_to_fs(str(path), **kwargs) + return fs def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None: diff --git a/tests/utilities/test_cloud_io.py b/tests/utilities/test_cloud_io.py new file mode 100644 index 0000000000..b2cbd5bead --- /dev/null +++ b/tests/utilities/test_cloud_io.py @@ -0,0 +1,34 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import fsspec +from fsspec.implementations.local import LocalFileSystem + +from pytorch_lightning.utilities.cloud_io import get_filesystem + + +def test_get_filesystem_custom_filesystem(): + _DUMMY_PRFEIX = "dummy" + + class DummyFileSystem(LocalFileSystem): + ... + + fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True) + output_file = os.path.join(f"{_DUMMY_PRFEIX}://", "tmpdir/tmp_file") + assert isinstance(get_filesystem(output_file), DummyFileSystem) + + +def test_get_filesystem_local_filesystem(): + assert isinstance(get_filesystem("tmpdir/tmp_file"), LocalFileSystem)