Fix usage of fs.listdir in CheckpointConnector (#15413)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
This commit is contained in:
Yuxuan Lu 2022-11-05 04:21:52 +08:00 committed by GitHub
parent 62d040c383
commit ee8a57da0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 6 deletions

View File

@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an attribute error in `ColossalAIStrategy` at import time when `torch.distributed` is not available ([#15535](https://github.com/Lightning-AI/lightning/pull/15535))
- Fixed an issue when calling `fs.listdir` with a file URI instead of a path in `CheckpointConnector` ([#15413](https://github.com/Lightning-AI/lightning/pull/15413))
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))

View File

@ -19,6 +19,8 @@ from copy import deepcopy
from typing import Any, Dict, Optional
import torch
from fsspec.core import url_to_fs
from fsspec.implementations.local import LocalFileSystem
from torch import Tensor
from torchmetrics import Metric
@ -59,13 +61,16 @@ class CheckpointConnector:
@property
def _hpc_resume_path(self) -> Optional[str]:
dir_path_hpc = self.trainer.default_root_dir
fs = get_filesystem(dir_path_hpc)
if not fs.isdir(dir_path_hpc):
return None
dir_path_hpc = str(dir_path_hpc)
fs, path = url_to_fs(dir_path_hpc)
if not fs.isdir(path):
return None
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_version is not None:
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
if isinstance(fs, LocalFileSystem):
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
else:
return dir_path_hpc + fs.sep + f"hpc_ckpt_{max_version}.ckpt"
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:
@ -565,12 +570,12 @@ class CheckpointConnector:
"""
# check directory existence
fs = get_filesystem(dir_path)
fs, uri = url_to_fs(str(dir_path))
if not fs.exists(dir_path):
return None
# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [os.path.basename(f["name"]) for f in fs.listdir(uri)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return None

View File

@ -102,6 +102,30 @@ def test_hpc_max_ckpt_version(tmpdir):
)
def test_ckpt_for_fsspec():
    """Check HPC checkpoint discovery when the root dir is an fsspec ``memory://`` URI.

    Several ``hpc_ckpt*`` checkpoints are saved under the root dir; the connector must report
    the highest-numbered one as the resume path and as the max version, and must return
    ``None`` for a folder that does not exist.
    """
    # The directory is hardcoded because `tmpdir` can be a Windows path.
    root_dir = "memory://test_ckpt_for_fsspec"
    boring_model = BoringModel()
    trainer = Trainer(default_root_dir=root_dir, limit_train_batches=1, limit_val_batches=1, max_epochs=1)
    trainer.fit(boring_model)

    # One unversioned checkpoint plus versions 0, 3 and 33, saved in the same order as before.
    checkpoint_uris = (
        "memory://test_ckpt_for_fsspec/hpc_ckpt.ckpt",
        "memory://test_ckpt_for_fsspec/hpc_ckpt_0.ckpt",
        "memory://test_ckpt_for_fsspec/hpc_ckpt_3.ckpt",
        "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt",
    )
    for checkpoint_uri in checkpoint_uris:
        trainer.save_checkpoint(checkpoint_uri)

    connector = trainer._checkpoint_connector
    assert connector._hpc_resume_path == "memory://test_ckpt_for_fsspec/hpc_ckpt_33.ckpt"

    # The private method is reached through its name-mangled attribute.
    existing_max = connector._CheckpointConnector__max_ckpt_version_in_folder("memory://test_ckpt_for_fsspec")
    assert existing_max == 33

    missing_max = connector._CheckpointConnector__max_ckpt_version_in_folder("memory://not_existing")
    assert missing_max is None
def test_loops_restore(tmpdir):
"""Test that required loop state_dict is loaded correctly by checkpoint connector."""
model = BoringModel()