diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index aa333a2994..da63750f59 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -20,11 +20,11 @@ import torch from torchmetrics import Metric import pytorch_lightning as pl -from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loops.utilities import _is_max_limit_reached +from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem +from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.migration import pl_legacy_patch @@ -53,7 +53,7 @@ class CheckpointConnector: if not os.path.isdir(self.trainer.weights_save_path): return None dir_path_hpc = str(self.trainer.weights_save_path) - max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_") + max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_") if max_version is not None: return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") @@ -300,41 +300,6 @@ class CheckpointConnector: # PRIVATE OPS # ---------------------------------- - def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str: - # make sure the checkpoint folder exists - folderpath = str(folderpath) # because the tests pass a path object - fs = get_filesystem(folderpath) - fs.makedirs(folderpath, exist_ok=True) - - # save logger to make sure we get all the metrics - if logger: - logger.finalize("finished") - - max_suffix = self.max_ckpt_version_in_folder(folderpath) - ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 - - fs.makedirs(folderpath, exist_ok=True) - filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt") - - # give model a chance to do something on hpc_save - model = self.trainer.lightning_module - checkpoint = self.dump_checkpoint() - - # TODO: remove this in v1.8. - model.on_hpc_save(checkpoint) - - # do the actual save - # TODO: fix for anything with multiprocess DP, DDP, DDP2 - try: - atomic_save(checkpoint, filepath) - except AttributeError as err: - if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: - del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] - rank_zero_warn(f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}") - atomic_save(checkpoint, filepath) - - return filepath - def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. Args: @@ -413,45 +378,13 @@ class CheckpointConnector: if self.trainer.datamodule is not None: self.trainer.datamodule.on_save_checkpoint(checkpoint) + # TODO: remove this in v1.8. + environment = self.trainer._accelerator_connector.cluster_environment + if isinstance(environment, SLURMEnvironment) and environment.auto_requeue: + model.on_hpc_save(checkpoint) + return checkpoint - def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]: - """List up files in `dir_path` with `name_key`, then yield maximum suffix number. - - Args: - dir_path: path of directory which may contain files whose name include `name_key` - name_key: file name prefix - Returns: - None if no-corresponding-file else maximum suffix number - """ - - # check directory existence - fs = get_filesystem(dir_path) - if not fs.exists(dir_path): - return None - - # check corresponding file existence - files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)] - files = [x for x in files if name_key in x] - if len(files) == 0: - return None - - # extract suffix number - ckpt_vs = [] - for name in files: - name = name.split(name_key)[-1] - name = re.sub("[^0-9]", "", name) - ckpt_vs.append(int(name)) - - return max(ckpt_vs) - - def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str: - """Get path of maximum-epoch checkpoint in the folder.""" - - max_suffix = self.max_ckpt_version_in_folder(folder_path) - ckpt_number = max_suffix if max_suffix is not None else 0 - return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt" - def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. @@ -489,3 +422,49 @@ class CheckpointConnector: "test_loop": self.trainer.test_loop.state_dict(), "predict_loop": self.trainer.predict_loop.state_dict(), } + + @staticmethod + def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]: + """List up files in `dir_path` with `name_key`, then yield maximum suffix number. + + Args: + dir_path: path of directory which may contain files whose name include `name_key` + name_key: file name prefix + Returns: + None if no-corresponding-file else maximum suffix number + """ + + # check directory existence + fs = get_filesystem(dir_path) + if not fs.exists(dir_path): + return None + + # check corresponding file existence + files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)] + files = [x for x in files if name_key in x] + if len(files) == 0: + return None + + # extract suffix number + ckpt_vs = [] + for name in files: + name = name.split(name_key)[-1] + name = re.sub("[^0-9]", "", name) + ckpt_vs.append(int(name)) + + return max(ckpt_vs) + + @staticmethod + def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str: + """Get path of maximum-epoch checkpoint in the folder.""" + + max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folder_path) + ckpt_number = max_suffix if max_suffix is not None else 0 + return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt" + + @staticmethod + def hpc_save_path(folderpath: _PATH) -> str: + max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folderpath) + ckpt_number = (max_suffix if max_suffix is not None else 0) + 1 + filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt") + return filepath diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 8b3291ecfa..60ecec9e2a 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Set, Union import pytorch_lightning as pl from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.utilities.distributed import rank_zero_info from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS # copied from signal.pyi @@ -62,11 +63,15 @@ class SignalConnector: self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: - if self.trainer.is_global_zero: - # save weights - log.info("handling SIGUSR1") - self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger) + rank_zero_info("handling SIGUSR1") + # save logger to make sure we get all the metrics + if self.trainer.logger: + self.trainer.logger.finalize("finished") + hpc_save_path = self.trainer.checkpoint_connector.hpc_save_path(self.trainer.weights_save_path) + self.trainer.save_checkpoint(hpc_save_path) + + if self.trainer.is_global_zero: # find job id job_id = os.environ["SLURM_JOB_ID"] cmd = ["scontrol", "requeue", job_id] @@ -88,10 +93,6 @@ class SignalConnector: else: log.warning("requeue failed...") - # close experiment to avoid issues - if self.trainer.logger: - self.trainer.logger.finalize("finished") - def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None: log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.") self.trainer._terminate_gracefully = True diff --git a/tests/helpers/pipelines.py b/tests/helpers/pipelines.py index 2f7d2d584d..71908d1a5b 100644 --- a/tests/helpers/pipelines.py +++ b/tests/helpers/pipelines.py @@ -82,9 +82,13 @@ def run_model_test( if with_hpc: # test HPC saving - trainer.checkpoint_connector.hpc_save(save_dir, logger) + # save logger to make sure we get all the metrics + if logger: + logger.finalize("finished") + hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir) + trainer.save_checkpoint(hpc_save_path) # test HPC loading - checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir) + checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir) trainer.checkpoint_connector.restore(checkpoint_path) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index f09257c83b..145ef3d954 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -90,9 +90,13 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True): pretrained_model(batch) # test HPC saving - trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger) + # save logger to make sure we get all the metrics + if trainer.logger: + trainer.logger.finalize("finished") + hpc_save_path = trainer.checkpoint_connector.hpc_save_path(ckpt_path) + trainer.save_checkpoint(hpc_save_path) # test HPC loading - checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path) + checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(ckpt_path) trainer.checkpoint_connector.restore(checkpoint_path) if on_gpu: diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index f1602c7bb1..a399967247 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -63,8 +63,12 @@ def test_cpu_slurm_save_load(tmpdir): # test HPC saving # simulate snapshot on slurm - saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger) - assert os.path.exists(saved_filepath) + # save logger to make sure we get all the metrics + if logger: + logger.finalize("finished") + hpc_save_path = trainer.checkpoint_connector.hpc_save_path(trainer.weights_save_path) + trainer.save_checkpoint(hpc_save_path) + assert os.path.exists(hpc_save_path) # new logger file to get meta logger = tutils.get_default_logger(tmpdir, version=version) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 9d6ecb08cd..d6511d8db3 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -522,7 +522,11 @@ def test_dp_resume(tmpdir): # HPC LOAD/SAVE # --------------------------- # save - trainer.checkpoint_connector.hpc_save(tmpdir, logger) + # save logger to make sure we get all the metrics + if logger: + logger.finalize("finished") + hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir) + trainer.save_checkpoint(hpc_save_path) # init new trainer new_logger = tutils.get_default_logger(tmpdir, version=logger.version) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index c0153638cc..c3d0f85958 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -20,11 +20,13 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from tests.helpers import BoringModel -class HPCHookdedModel(BoringModel): +# TODO: remove HPCHookedModel in v1.8 +class HPCHookedModel(BoringModel): def __init__(self): super().__init__() self.hpc_save_called = 0 @@ -39,15 +41,25 @@ class HPCHookdedModel(BoringModel): self.hpc_load_called += 1 -def test_hpc_hook_calls(tmpdir): - model = HPCHookdedModel() +# TODO: remove test_hpc_hook_calls in v1.8 +@mock.patch( + "pytorch_lightning.trainer.connectors.accelerator_connector.AcceleratorConnector._is_slurm_managing_tasks", + return_value=True, +) +def test_hpc_hook_calls(mock_slurm_env, tmpdir): + model = HPCHookedModel() trainer = Trainer(default_root_dir=tmpdir, max_steps=1, enable_checkpointing=False, logger=False) + environment = trainer._accelerator_connector.cluster_environment + assert isinstance(environment, SLURMEnvironment) + assert environment.auto_requeue with pytest.deprecated_call( match=r"Method `LightningModule.on_hpc_save` is deprecated in v1.6 and will be removed in v1.8." ): trainer.fit(model) - connector = trainer.checkpoint_connector - connector.hpc_save(tmpdir, logger=Mock()) + + # simulate snapshot on slurm + hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir) + trainer.save_checkpoint(hpc_save_path) assert model.hpc_save_called == 1 assert model.hpc_load_called == 0 @@ -134,8 +146,11 @@ def test_hpc_max_ckpt_version(tmpdir): trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt") assert trainer.checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt") - assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33 - assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None + assert trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33 + assert ( + trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing") + is None + ) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})