From 5a3007cd6c417cb24f7114060e83cc63f64f648c Mon Sep 17 00:00:00 2001
From: Max Ehrlich
Date: Mon, 10 Oct 2022 07:43:57 -0400
Subject: [PATCH] Support Slurm Autorequeue for Array Jobs (#15040)

Signed-off-by: Max Ehrlich
Co-authored-by: awaelchli
---
 src/pytorch_lightning/CHANGELOG.md             |  3 +++
 .../trainer/connectors/signal_connector.py     |  8 +++++-
 .../connectors/test_signal_connector.py        | 26 +++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 9b91171d37..36a115cc7c 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added support for requeueing slurm array jobs ([#15022](https://github.com/Lightning-AI/lightning/issues/15022))
+
+
 - Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))

diff --git a/src/pytorch_lightning/trainer/connectors/signal_connector.py b/src/pytorch_lightning/trainer/connectors/signal_connector.py
index 83a9e38ce0..540155e2dc 100644
--- a/src/pytorch_lightning/trainer/connectors/signal_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/signal_connector.py
@@ -78,7 +78,13 @@ class SignalConnector:
         if self.trainer.is_global_zero:
             # find job id
-            job_id = os.environ["SLURM_JOB_ID"]
+            array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
+            if array_job_id is not None:
+                array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
+                job_id = f"{array_job_id}_{array_task_id}"
+            else:
+                job_id = os.environ["SLURM_JOB_ID"]
+
             cmd = ["scontrol", "requeue", job_id]
 
             # requeue job

diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
index 21b9364d2e..a35f5f28dc 100644
--- a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py
@@ -100,6 +100,32 @@ def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
     connector.teardown()
 
 
+@RunIf(skip_windows=True)
+@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
+@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
+def test_auto_requeue_job(call_mock):
+    call_mock.return_value = 0
+    trainer = Trainer(plugins=[SLURMEnvironment()])
+    connector = SignalConnector(trainer)
+    connector.slurm_sigusr_handler_fn(None, None)
+    call_mock.assert_called_once_with(["scontrol", "requeue", "12345"])
+    connector.teardown()
+
+
+@RunIf(skip_windows=True)
+@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
+@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
+@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"})
+def test_auto_requeue_array_job(call_mock):
+    call_mock.return_value = 0
+    trainer = Trainer(plugins=[SLURMEnvironment()])
+    connector = SignalConnector(trainer)
+    connector.slurm_sigusr_handler_fn(None, None)
+    call_mock.assert_called_once_with(["scontrol", "requeue", "12345_2"])
+    connector.teardown()
+
+
 def _registering_signals():
     trainer = Trainer()
     trainer._signal_connector.register_signal_handlers()
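
Note for reviewers: the core of the change is the job-id resolution in signal_connector.py. As a standalone sketch of that logic (the helper name resolve_requeue_job_id is ours for illustration; the patched code inlines this directly in the SIGUSR handler), the behavior exercised by the two new tests is:

    import os

    def resolve_requeue_job_id() -> str:
        # When the job is an array task, the requeue target must be the
        # "<array_job_id>_<task_id>" form; for an array task, SLURM_JOB_ID
        # holds a different per-task id that is not what we want to requeue.
        array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
        if array_job_id is not None:
            array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
            return f"{array_job_id}_{array_task_id}"
        # Plain (non-array) jobs fall back to the ordinary job id.
        return os.environ["SLURM_JOB_ID"]

    if __name__ == "__main__":
        # Hypothetical environment matching test_auto_requeue_array_job:
        os.environ.update(
            {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"}
        )
        print(["scontrol", "requeue", resolve_requeue_job_id()])
        # -> ['scontrol', 'requeue', '12345_2']

The os.getenv/os.environ split mirrors the patch: SLURM_ARRAY_JOB_ID is optional (absent for non-array jobs), while SLURM_ARRAY_TASK_ID is assumed to be set whenever the array job id is present.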