Remove `hpc_save` (#11101)

Author: jjenniferdai, 2022-01-03 04:23:13 -08:00 (committed by GitHub)
Parent: 7637550ab5
Commit: 4b5761539e
7 changed files with 108 additions and 97 deletions


@@ -20,11 +20,11 @@ import torch
from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.migration import pl_legacy_patch
@@ -53,7 +53,7 @@ class CheckpointConnector:
if not os.path.isdir(self.trainer.weights_save_path):
return None
dir_path_hpc = str(self.trainer.weights_save_path)
max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
if max_version is not None:
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")
@@ -300,41 +300,6 @@ class CheckpointConnector:
# PRIVATE OPS
# ----------------------------------
def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
# make sure the checkpoint folder exists
folderpath = str(folderpath) # because the tests pass a path object
fs = get_filesystem(folderpath)
fs.makedirs(folderpath, exist_ok=True)
# save logger to make sure we get all the metrics
if logger:
logger.finalize("finished")
max_suffix = self.max_ckpt_version_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
fs.makedirs(folderpath, exist_ok=True)
filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
# give model a chance to do something on hpc_save
model = self.trainer.lightning_module
checkpoint = self.dump_checkpoint()
# TODO: remove this in v1.8.
model.on_hpc_save(checkpoint)
# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
try:
atomic_save(checkpoint, filepath)
except AttributeError as err:
if pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
del checkpoint[pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
rank_zero_warn(f"warning, `hyper_parameters` dropped from checkpoint. An attribute is not picklable {err}")
atomic_save(checkpoint, filepath)
return filepath
def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states.
Args:
@@ -413,45 +378,13 @@ class CheckpointConnector:
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_save_checkpoint(checkpoint)
# TODO: remove this in v1.8.
environment = self.trainer._accelerator_connector.cluster_environment
if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
model.on_hpc_save(checkpoint)
return checkpoint
def max_ckpt_version_in_folder(self, dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
name_key: file name prefix
Returns:
None if no-corresponding-file else maximum suffix number
"""
# check directory existence
fs = get_filesystem(dir_path)
if not fs.exists(dir_path):
return None
# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return None
# extract suffix number
ckpt_vs = []
for name in files:
name = name.split(name_key)[-1]
name = re.sub("[^0-9]", "", name)
ckpt_vs.append(int(name))
return max(ckpt_vs)
def get_max_ckpt_path_from_folder(self, folder_path: _PATH) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""
max_suffix = self.max_ckpt_version_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
@@ -489,3 +422,49 @@ class CheckpointConnector:
"test_loop": self.trainer.test_loop.state_dict(),
"predict_loop": self.trainer.predict_loop.state_dict(),
}
@staticmethod
def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Optional[int]:
"""List up files in `dir_path` with `name_key`, then yield maximum suffix number.
Args:
dir_path: path of directory which may contain files whose name include `name_key`
name_key: file name prefix
Returns:
None if no-corresponding-file else maximum suffix number
"""
# check directory existence
fs = get_filesystem(dir_path)
if not fs.exists(dir_path):
return None
# check corresponding file existence
files = [os.path.basename(f["name"]) for f in fs.listdir(dir_path)]
files = [x for x in files if name_key in x]
if len(files) == 0:
return None
# extract suffix number
ckpt_vs = []
for name in files:
name = name.split(name_key)[-1]
name = re.sub("[^0-9]", "", name)
ckpt_vs.append(int(name))
return max(ckpt_vs)
@staticmethod
def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str:
"""Get path of maximum-epoch checkpoint in the folder."""
max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folder_path)
ckpt_number = max_suffix if max_suffix is not None else 0
return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
@staticmethod
def hpc_save_path(folderpath: _PATH) -> str:
max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folderpath)
ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
return filepath


@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Set, Union
import pytorch_lightning as pl
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities.distributed import rank_zero_info
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
# copied from signal.pyi
@@ -62,11 +63,15 @@ class SignalConnector:
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
if self.trainer.is_global_zero:
# save weights
log.info("handling SIGUSR1")
self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)
rank_zero_info("handling SIGUSR1")
# save logger to make sure we get all the metrics
if self.trainer.logger:
self.trainer.logger.finalize("finished")
hpc_save_path = self.trainer.checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
self.trainer.save_checkpoint(hpc_save_path)
if self.trainer.is_global_zero:
# find job id
job_id = os.environ["SLURM_JOB_ID"]
cmd = ["scontrol", "requeue", job_id]
@@ -88,10 +93,6 @@ class SignalConnector:
else:
log.warning("requeue failed...")
# close experiment to avoid issues
if self.trainer.logger:
self.trainer.logger.finalize("finished")
def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
self.trainer._terminate_gracefully = True
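
In the reworked SIGUSR1 handler above, only the SLURM requeue step stays behind the `is_global_zero` guard; logger finalization and the checkpoint save now run on every rank, with `save_checkpoint` coordinating the actual write internally. A standalone sketch of that flow, for orientation only (the helper name `handle_sigusr1` and the direct `subprocess.call` are illustrative assumptions, not the library's implementation):

```python
import os
import subprocess

def handle_sigusr1(trainer):
    """Sketch of the post-change SIGUSR1 flow; illustrative, not the library code."""
    # flush logged metrics before snapshotting
    if trainer.logger:
        trainer.logger.finalize("finished")
    # every rank calls save_checkpoint; the path is versioned as hpc_ckpt_<n>.ckpt
    hpc_save_path = trainer.checkpoint_connector.hpc_save_path(trainer.weights_save_path)
    trainer.save_checkpoint(hpc_save_path)
    # only rank 0 asks SLURM to requeue the job
    if trainer.is_global_zero:
        job_id = os.environ["SLURM_JOB_ID"]
        subprocess.call(["scontrol", "requeue", job_id])
```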


@@ -82,9 +82,13 @@ def run_model_test(
if with_hpc:
# test HPC saving
trainer.checkpoint_connector.hpc_save(save_dir, logger)
# save logger to make sure we get all the metrics
if logger:
logger.finalize("finished")
hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
trainer.save_checkpoint(hpc_save_path)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
trainer.checkpoint_connector.restore(checkpoint_path)
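
Because the helpers became double-underscore static methods, the tests reach them through Python's class-private name mangling: `__name` defined inside `class CheckpointConnector` is rewritten to `_CheckpointConnector__name`. A minimal self-contained illustration of that mechanism (toy class, not PyTorch Lightning code):

```python
class Connector:
    @staticmethod
    def __max_ckpt_version_in_folder(dir_path):
        # toy stand-in for the real directory scan
        return 33

# Inside the class body the name is rewritten to `_Connector__max_ckpt_version_in_folder`,
# which is how code outside the class (like the tests above) can still call it:
assert Connector._Connector__max_ckpt_version_in_folder("any/dir") == 33
```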


@@ -90,9 +90,13 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True):
pretrained_model(batch)
# test HPC saving
trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
# save logger to make sure we get all the metrics
if trainer.logger:
trainer.logger.finalize("finished")
hpc_save_path = trainer.checkpoint_connector.hpc_save_path(ckpt_path)
trainer.save_checkpoint(hpc_save_path)
# test HPC loading
checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(ckpt_path)
trainer.checkpoint_connector.restore(checkpoint_path)
if on_gpu:


@@ -63,8 +63,12 @@ def test_cpu_slurm_save_load(tmpdir):
# test HPC saving
# simulate snapshot on slurm
saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
assert os.path.exists(saved_filepath)
# save logger to make sure we get all the metrics
if logger:
logger.finalize("finished")
hpc_save_path = trainer.checkpoint_connector.hpc_save_path(trainer.weights_save_path)
trainer.save_checkpoint(hpc_save_path)
assert os.path.exists(hpc_save_path)
# new logger file to get meta
logger = tutils.get_default_logger(tmpdir, version=version)


@@ -522,7 +522,11 @@ def test_dp_resume(tmpdir):
# HPC LOAD/SAVE
# ---------------------------
# save
trainer.checkpoint_connector.hpc_save(tmpdir, logger)
# save logger to make sure we get all the metrics
if logger:
logger.finalize("finished")
hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir)
trainer.save_checkpoint(hpc_save_path)
# init new trainer
new_logger = tutils.get_default_logger(tmpdir, version=logger.version)


@@ -20,11 +20,13 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers import BoringModel
class HPCHookdedModel(BoringModel):
# TODO: remove HPCHookedModel in v1.8
class HPCHookedModel(BoringModel):
def __init__(self):
super().__init__()
self.hpc_save_called = 0
@@ -39,15 +41,25 @@ class HPCHookdedModel(BoringModel):
self.hpc_load_called += 1
def test_hpc_hook_calls(tmpdir):
model = HPCHookdedModel()
# TODO: remove test_hpc_hook_calls in v1.8
@mock.patch(
"pytorch_lightning.trainer.connectors.accelerator_connector.AcceleratorConnector._is_slurm_managing_tasks",
return_value=True,
)
def test_hpc_hook_calls(mock_slurm_env, tmpdir):
model = HPCHookedModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, enable_checkpointing=False, logger=False)
environment = trainer._accelerator_connector.cluster_environment
assert isinstance(environment, SLURMEnvironment)
assert environment.auto_requeue
with pytest.deprecated_call(
match=r"Method `LightningModule.on_hpc_save` is deprecated in v1.6 and will be removed in v1.8."
):
trainer.fit(model)
connector = trainer.checkpoint_connector
connector.hpc_save(tmpdir, logger=Mock())
# simulate snapshot on slurm
hpc_save_path = trainer.checkpoint_connector.hpc_save_path(tmpdir)
trainer.save_checkpoint(hpc_save_path)
assert model.hpc_save_called == 1
assert model.hpc_load_called == 0
@@ -134,8 +146,11 @@ def test_hpc_max_ckpt_version(tmpdir):
trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")
assert trainer.checkpoint_connector._hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
assert trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir) == 33
assert (
trainer.checkpoint_connector._CheckpointConnector__max_ckpt_version_in_folder(tmpdir / "not" / "existing")
is None
)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
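
The numbering test above pins how the highest existing `hpc_ckpt_<n>` suffix is found. For reference, a standalone re-implementation of that suffix scan, for illustration only (the real static method also lists the directory via `get_filesystem` instead of taking file names directly):

```python
import os
import re
from typing import List, Optional

def max_ckpt_version_in_folder(file_names: List[str], name_key: str = "ckpt_") -> Optional[int]:
    """Return the largest numeric suffix among names containing `name_key`, or None."""
    candidates = [os.path.basename(f) for f in file_names if name_key in os.path.basename(f)]
    if not candidates:
        return None
    return max(int(re.sub("[^0-9]", "", name.split(name_key)[-1])) for name in candidates)

assert max_ckpt_version_in_folder(["hpc_ckpt_1.ckpt", "hpc_ckpt_33.ckpt"]) == 33
assert max_ckpt_version_in_folder(["model.pt"]) is None  # no matching files
```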