Support Slurm Autorequeue for Array Jobs (#15040)
Signed-off-by: Max Ehrlich <max.ehr@gmail.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
ddfcddbd1c
commit
5a3007cd6c
|
@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added support for requeueing Slurm array jobs ([#15022](https://github.com/Lightning-AI/lightning/issues/15022))
|
||||
|
||||
|
||||
- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))
|
||||
|
||||
|
||||
|
|
|
@ -78,7 +78,13 @@ class SignalConnector:
|
|||
|
||||
if self.trainer.is_global_zero:
|
||||
# find job id
|
||||
job_id = os.environ["SLURM_JOB_ID"]
|
||||
array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
|
||||
if array_job_id is not None:
|
||||
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
|
||||
job_id = f"{array_job_id}_{array_task_id}"
|
||||
else:
|
||||
job_id = os.environ["SLURM_JOB_ID"]
|
||||
|
||||
cmd = ["scontrol", "requeue", job_id]
|
||||
|
||||
# requeue job
|
||||
|
|
|
@ -100,6 +100,32 @@ def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
|
|||
connector.teardown()
|
||||
|
||||
|
||||
@RunIf(skip_windows=True)
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
def test_auto_requeue_job(call_mock):
    """Check that the SIGUSR handler requeues a regular job by its ``SLURM_JOB_ID``.

    The ``call`` used by the signal connector is mocked, so no real
    ``scontrol`` process is spawned; checkpoint saving is mocked out as well.
    """
    # Report a zero exit status from the mocked `scontrol` invocation.
    call_mock.return_value = 0

    signal_connector = SignalConnector(Trainer(plugins=[SLURMEnvironment()]))
    signal_connector.slurm_sigusr_handler_fn(None, None)

    # With only SLURM_JOB_ID patched in, the bare job id is what gets requeued.
    expected_cmd = ["scontrol", "requeue", "12345"]
    call_mock.assert_called_once_with(expected_cmd)

    signal_connector.teardown()
|
||||
|
||||
|
||||
@RunIf(skip_windows=True)
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
@mock.patch.dict(
    os.environ,
    {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"},
)
def test_auto_requeue_array_job(call_mock):
    """Check that the SIGUSR handler requeues an array task as ``<array job id>_<task id>``.

    ``SLURM_JOB_ID`` is deliberately set to a different value ("12346") so the
    assertion fails if the handler requeues by the plain job id.
    """
    # Report a zero exit status from the mocked `scontrol` invocation.
    call_mock.return_value = 0

    signal_connector = SignalConnector(Trainer(plugins=[SLURMEnvironment()]))
    signal_connector.slurm_sigusr_handler_fn(None, None)

    # The requeue target combines SLURM_ARRAY_JOB_ID and SLURM_ARRAY_TASK_ID.
    expected_cmd = ["scontrol", "requeue", "12345_2"]
    call_mock.assert_called_once_with(expected_cmd)

    signal_connector.teardown()
|
||||
|
||||
|
||||
def _registering_signals():
    """Build a default ``Trainer`` and register its signal handlers."""
    default_trainer = Trainer()
    connector = default_trainer._signal_connector
    connector.register_signal_handlers()
|
||||
|
|
Loading…
Reference in New Issue