diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 3bccaee208..6a600a058a 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
 
+- Fixed an issue when calling `fs.listdir` with a file URI instead of a path in `CheckpointConnector` ([#15413](https://github.com/Lightning-AI/lightning/pull/15413))
+
 - Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
 
 
diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index bff97a7fbc..a3aa02dd6e 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -19,6 +19,8 @@ from copy import deepcopy
 from typing import Any, Dict, Optional
 
 import torch
+from fsspec.core import url_to_fs
+from fsspec.implementations.local import LocalFileSystem
 from torch import Tensor
 from torchmetrics import Metric
 
@@ -59,13 +61,16 @@ class CheckpointConnector:
     @property
     def _hpc_resume_path(self) -> Optional[str]:
         dir_path_hpc = self.trainer.default_root_dir
-        fs = get_filesystem(dir_path_hpc)
-        if not fs.isdir(dir_path_hpc):
-            return None
         dir_path_hpc = str(dir_path_hpc)
+        fs, path = url_to_fs(dir_path_hpc)
+        if not fs.isdir(path):
+            return None
         max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_version is not None:
-            return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
+            if isinstance(fs, LocalFileSystem):
+                return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
+            else:
+                return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
 
     def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@@ -565,12 +570,12 @@ class CheckpointConnector:
         """
 
         # check directory existence
-        fs = get_filesystem(dir_path)
+        fs, uri = url_to_fs(str(dir_path))
         if not fs.exists(dir_path):
             return None
 
         # check corresponding file existence
-        files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
+        files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
         files = [x for x in files if name_key in x]
         if len(files) == 0:
             return None
diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
index 349bf93017..9d69ad1bd3 100644
--- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py
@@ -102,6 +102,30 @@ def test_hpc_max_ckpt_version(tmpdir):
     )
 
 
+def test_ckpt_for_fsspec():
+    """Test that the CheckpointConnector is able to write to fsspec file systems."""
+
+    model = BoringModel()
+    # hardcode the dir since `tmpdir` can be a Windows path
+    trainer = Trainer(
+        default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1
+    )
+    trainer.fit(model)
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_0.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_3.ckpt")
+    trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt")
+
+    assert trainer._checkpoint_connector._hpc_resume_path == "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt"
+    assert (
+        trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://test_ckpt_for_fsspec")
+        == 33
+    )
+    assert (
+        trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://not_existing") is None
+    )
+
+
 def test_loops_restore(tmpdir):
     """Test that required loop state_dict is loaded correctly by checkpoint connector."""
     model = BoringModel()