Fix usage of fs.listdir in CheckpointConnector (#15413)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
This commit is contained in:
parent
62d040c383
commit
ee8a57da0f
|
@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
|
||||
|
||||
- Fixed an issue when calling `fs.listdir` with file URI instead of path in `CheckpointConnector` ([#15413](https://github.com/Lightning-AI/lightning/pull/15413))
|
||||
|
||||
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@ from copy import deepcopy
|
|||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from fsspec.core import url_to_fs
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
from torch import Tensor
|
||||
from torchmetrics import Metric
|
||||
|
||||
|
@ -59,13 +61,16 @@ class CheckpointConnector:
|
|||
@property
|
||||
def _hpc_resume_path(self) -> Optional[str]:
|
||||
dir_path_hpc = self.trainer.default_root_dir
|
||||
fs = get_filesystem(dir_path_hpc)
|
||||
if not fs.isdir(dir_path_hpc):
|
||||
return None
|
||||
dir_path_hpc = str(dir_path_hpc)
|
||||
fs, path = url_to_fs(dir_path_hpc)
|
||||
if not fs.isdir(path):
|
||||
return None
|
||||
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
|
||||
if max_version is not None:
|
||||
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
|
||||
if isinstance(fs, LocalFileSystem):
|
||||
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
|
||||
else:
|
||||
return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
|
||||
|
||||
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
|
||||
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
|
||||
|
@ -565,12 +570,12 @@ class CheckpointConnector:
|
|||
"""
|
||||
|
||||
# check directory existence
|
||||
fs = get_filesystem(dir_path)
|
||||
fs, uri = url_to_fs(str(dir_path))
|
||||
if not fs.exists(dir_path):
|
||||
return None
|
||||
|
||||
# check corresponding file existence
|
||||
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
|
||||
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
|
||||
files = [x for x in files if name_key in x]
|
||||
if len(files) == 0:
|
||||
return None
|
||||
|
|
|
@ -102,6 +102,30 @@ def test_hpc_max_ckpt_version(tmpdir):
|
|||
)
|
||||
|
||||
|
||||
def test_ckpt_for_fsspec():
|
||||
"""Test that the CheckpointConnector is able to write to fsspec file systems."""
|
||||
|
||||
model = BoringModel()
|
||||
# hardcoding dir since `tmpdir` can be windows path
|
||||
trainer = Trainer(
|
||||
default_root_dir="memory://test_ckpt_for_fsspec", limit_train_batches=1, limit_val_batches=1, max_epochs=1
|
||||
)
|
||||
trainer.fit(model)
|
||||
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt.ckpt")
|
||||
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_0.ckpt")
|
||||
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_3.ckpt")
|
||||
trainer.save_checkpoint("memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt")
|
||||
|
||||
assert trainer._checkpoint_connector._hpc_resume_path == "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt"
|
||||
assert (
|
||||
trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://test_ckpt_for_fsspec")
|
||||
== 33
|
||||
)
|
||||
assert (
|
||||
trainer._checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder("memory://not_existing") is None
|
||||
)
|
||||
|
||||
|
||||
def test_loops_restore(tmpdir):
|
||||
"""Test that required loop state_dict is loaded correctly by checkpoint connector."""
|
||||
model = BoringModel()
|
||||
|
|
Loading…
Reference in New Issue