Remove `hpc_save` (#11101)
parent 7637550ab5
commit 4b5761539e
@@ -20,11 +20,11 @@ import torch
 from torchmetrics import Metric

 import pytorch_lightning as pl
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.migration import pl_legacy_patch
@@ -53,7 +53,7 @@ class CheckpointConnector:
         if not os.path.isdir(self.trainer.weights_save_path):
             return None
         dir_path_hpc = str(self.trainer.weights_save_path)
-        max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
+        max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
         if max_version is not None:
             return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

@@ -300,41 +300,6 @@ class CheckpointConnector:
     # PRIVATE OPS
     # ----------------------------------

-    def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
-        # make sure the checkpoint folder exists
-        folderpath = str(folderpath)  # because the tests pass a path object
-        fs = get_filesystem(folderpath)
-        fs.makedirs(folderpath, exist_ok=True)
-
-        # save logger to make sure we get all the metrics
-        if logger:
-            logger.finalize("finished")
-
-        max_suffix = self.max_ckpt_version_in_folder(folderpath)
-        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
-
-        fs.makedirs(folderpath, exist_ok=True)
-        filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
-
-        # give model a chance to do something on hpc_save
-        model = self.trainer.lightning_module
-        checkpoint = self.dump_checkpoint()
-
-        # TODO: remove this in v1.8.
-        model.on_hpc_save(checkpoint)
-
-        # do the actual save
-        # TODO: fix for anything with multiprocess DP, DDP, DDP2
-        try:
-            atomic_save(checkpoint, filepath)
-        except AttributeError as err:
-            if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
-            rank_zero_warn(f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}")
-            atomic_save(checkpoint, filepath)
-
-        return filepath
-
     def dump_checkpoint(self, weights_only: bool = False) -> dict:
         """Creating a model checkpoint dictionary object from various component states.

         Args:
@@ -413,45 +378,13 @@ class CheckpointConnector:
         if self.trainer.datamodule is not None:
             self.trainer.datamodule.on_save_checkpoint(checkpoint)

+        # TODO: remove this in v1.8.
+        environment = self.trainer._accelerator_connector.cluster_environment
+        if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
+            model.on_hpc_save(checkpoint)
+
         return checkpoint

-    def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
-        """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
-
-        Args:
-            dir_path: path of directory which may contain files whose name include `name_key`
-            name_key: file name prefix
-        Returns:
-            None if no-corresponding-file else maximum suffix number
-        """
-
-        # check directory existence
-        fs = get_filesystem(dir_path)
-        if not fs.exists(dir_path):
-            return None
-
-        # check corresponding file existence
-        files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
-        files = [x for x in files if name_key in x]
-        if len(files) == 0:
-            return None
-
-        # extract suffix number
-        ckpt_vs = []
-        for name in files:
-            name = name.split(name_key)[-1]
-            name = re.sub("[^0-9]", "", name)
-            ckpt_vs.append(int(name))
-
-        return max(ckpt_vs)
-
-    def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
-        """Get path of maximum-epoch checkpoint in the folder."""
-
-        max_suffix = self.max_ckpt_version_in_folder(folder_path)
-        ckpt_number = max_suffix if max_suffix is not None else 0
-        return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
-
     def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

@@ -489,3 +422,49 @@ class CheckpointConnector:
             "test_loop": self.trainer.test_loop.state_dict(),
             "predict_loop": self.trainer.predict_loop.state_dict(),
         }
+
+    @staticmethod
+    def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
+        """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
+
+        Args:
+            dir_path: path of directory which may contain files whose name include `name_key`
+            name_key: file name prefix
+        Returns:
+            None if no-corresponding-file else maximum suffix number
+        """
+
+        # check directory existence
+        fs = get_filesystem(dir_path)
+        if not fs.exists(dir_path):
+            return None
+
+        # check corresponding file existence
+        files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
+        files = [x for x in files if name_key in x]
+        if len(files) == 0:
+            return None
+
+        # extract suffix number
+        ckpt_vs = []
+        for name in files:
+            name = name.split(name_key)[-1]
+            name = re.sub("[^0-9]", "", name)
+            ckpt_vs.append(int(name))
+
+        return max(ckpt_vs)
+
+    @staticmethod
+    def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str:
+        """Get path of maximum-epoch checkpoint in the folder."""
+
+        max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folder_path)
+        ckpt_number = max_suffix if max_suffix is not None else 0
+        return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
+
+    @staticmethod
+    def hpc_save_path(folderpath: _PATH) -> str:
+        max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folderpath)
+        ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
+        filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
+        return filepath
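The call-site updates below all follow the same pattern, so downstream code that still relied on `hpc_save` can migrate along these lines. This is only a sketch with a hypothetical helper name; `trainer`, `folder`, and `logger` stand in for whatever the caller already has in scope:

    def save_hpc_checkpoint(trainer, folder, logger=None):
        """Replacement for the removed `CheckpointConnector.hpc_save` (illustrative only)."""
        if logger:
            logger.finalize("finished")  # flush logger metrics before snapshotting, as the updated call sites do
        ckpt_path = trainer.checkpoint_connector.hpc_save_path(folder)  # resolves to "<folder>/hpc_ckpt_<n>.ckpt"
        trainer.save_checkpoint(ckpt_path)  # saving now goes through the regular Trainer checkpoint path
        return ckpt_path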
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Set, Union

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS

 # copied from signal.pyi
@@ -62,11 +63,15 @@ class SignalConnector:
             self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))

     def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
-        if self.trainer.is_global_zero:
-            # save weights
-            log.info("handling SIGUSR1")
-            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)
+        rank_zero_info("handling SIGUSR1")
+
+        # save logger to make sure we get all the metrics
+        if self.trainer.logger:
+            self.trainer.logger.finalize("finished")
+        hpc_save_path = self.trainer.checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
+        self.trainer.save_checkpoint(hpc_save_path)

+        if self.trainer.is_global_zero:
             # find job id
             job_id = os.environ["SLURM_JOB_ID"]
             cmd = ["scontrol", "requeue", job_id]
@@ -88,10 +93,6 @@ class SignalConnector:
             else:
                 log.warning("requeue failed...")

-            # close experiment to avoid issues
-            if self.trainer.logger:
-                self.trainer.logger.finalize("finished")
-
     def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
         log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
         self.trainer._terminate_gracefully = True
@@ -82,9 +82,13 @@ def run_model_test(

     if with_hpc:
         # test HPC saving
-        trainer.checkpoint_connector.hpc_save(save_dir, logger)
+        # save logger to make sure we get all the metrics
+        if logger:
+            logger.finalize("finished")
+        hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
+        trainer.save_checkpoint(hpc_save_path)
         # test HPC loading
-        checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
+        checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
         trainer.checkpoint_connector.restore(checkpoint_path)

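Since the folder helpers are now double-underscore methods on `CheckpointConnector`, the tests reach them through Python's ordinary name mangling: a method named `__get_max_ckpt_path_from_folder` defined inside the class is stored on it as `_CheckpointConnector__get_max_ckpt_path_from_folder`, which is why the call sites above and below spell out the mangled name. A minimal illustration of the mechanism (toy class, not Lightning code):

    class ConnectorDemo:
        @staticmethod
        def __max_ckpt_version_in_folder(dir_path):  # interpreter mangles this name
            return 33

    # inside the class body you would write: cls.__max_ckpt_version_in_folder(...)
    # outside the class (e.g. in tests) the mangled attribute must be used:
    assert ConnectorDemo._ConnectorDemo__max_ckpt_version_in_folder("ckpts") == 33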
@@ -90,9 +90,13 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True):
         pretrained_model(batch)

     # test HPC saving
-    trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
+    # save logger to make sure we get all the metrics
+    if trainer.logger:
+        trainer.logger.finalize("finished")
+    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(ckpt_path)
+    trainer.save_checkpoint(hpc_save_path)
     # test HPC loading
-    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
+    checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(ckpt_path)
     trainer.checkpoint_connector.restore(checkpoint_path)

     if on_gpu:
@@ -63,8 +63,12 @@ def test_cpu_slurm_save_load(tmpdir):

     # test HPC saving
     # simulate snapshot on slurm
-    saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
-    assert os.path.exists(saved_filepath)
+    # save logger to make sure we get all the metrics
+    if logger:
+        logger.finalize("finished")
+    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(trainer.weights_save_path)
+    trainer.save_checkpoint(hpc_save_path)
+    assert os.path.exists(hpc_save_path)

     # new logger file to get meta
     logger = tutils.get_default_logger(tmpdir, version=version)
@@ -522,7 +522,11 @@ def test_dp_resume(tmpdir):
     # HPC LOAD/SAVE
     # ---------------------------
     # save
-    trainer.checkpoint_connector.hpc_save(tmpdir, logger)
+    # save logger to make sure we get all the metrics
+    if logger:
+        logger.finalize("finished")
+    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir)
+    trainer.save_checkpoint(hpc_save_path)

     # init new trainer
     new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
@@ -20,11 +20,13 @@ import torch

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from tests.helpers import BoringModel


-class HPCHookdedModel(BoringModel):
+# TODO: remove HPCHookedModel in v1.8
+class HPCHookedModel(BoringModel):
     def __init__(self):
         super().__init__()
         self.hpc_save_called = 0
@@ -39,15 +41,25 @@ class HPCHookdedModel(BoringModel):
         self.hpc_load_called += 1


-def test_hpc_hook_calls(tmpdir):
-    model = HPCHookdedModel()
+# TODO: remove test_hpc_hook_calls in v1.8
+@mock.patch(
+    "pytorch_lightning.trainer.connectors.accelerator_connector.AcceleratorConnector._is_slurm_managing_tasks",
+    return_value=True,
+)
+def test_hpc_hook_calls(mock_slurm_env, tmpdir):
+    model = HPCHookedModel()
     trainer = Trainer(default_root_dir=tmpdir, max_steps=1, enable_checkpointing=False, logger=False)
-    trainer.fit(model)
-    connector = trainer.checkpoint_connector
-    connector.hpc_save(tmpdir, logger=Mock())
+    environment = trainer._accelerator_connector.cluster_environment
+    assert isinstance(environment, SLURMEnvironment)
+    assert environment.auto_requeue
+    with pytest.deprecated_call(
+        match=r"Method `LightningModule.on_hpc_save` is deprecated in v1.6 and will be removed in v1.8."
+    ):
+        trainer.fit(model)
+
+    # simulate snapshot on slurm
+    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir)
+    trainer.save_checkpoint(hpc_save_path)
     assert model.hpc_save_called == 1
     assert model.hpc_load_called == 0
@@ -134,8 +146,11 @@ def test_hpc_max_ckpt_version(tmpdir):
     trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

     assert trainer.checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
-    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
-    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
+    assert trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33
+    assert (
+        trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing")
+        is None
+    )


 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
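For reference, `hpc_save_path` picks the next free suffix rather than the current maximum, so in the situation exercised above (a folder already holding `hpc_ckpt_33.ckpt`) the next save lands one number higher. A rough sketch of that expectation, reusing the test's `trainer` and `tmpdir` and not an assertion taken from the PR:

    # given that hpc_ckpt_33.ckpt already exists in `tmpdir`:
    next_path = trainer.checkpoint_connector.hpc_save_path(tmpdir)
    assert next_path == os.path.join(str(tmpdir), "hpc_ckpt_34.ckpt")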