62 lines
2.0 KiB
Python
62 lines
2.0 KiB
Python
import logging
|
|
import os
|
|
import signal
|
|
from subprocess import call
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class SLURMConnector:
    """Wire a trainer into SLURM's signal-driven requeue flow.

    Inside a non-interactive SLURM job, ``SIGUSR1`` triggers a checkpoint
    save followed by ``scontrol requeue <job id>``, and ``SIGTERM`` is
    intercepted and merely logged.
    """

    def __init__(self, trainer):
        """Keep a reference to the trainer used by the signal handlers."""
        self.trainer = trainer

    def register_slurm_signal_handlers(self):
        """Install the ``SIGUSR1``/``SIGTERM`` handlers when running under SLURM.

        Nothing is registered when ``SLURM_JOB_NAME`` is absent from the
        environment, or when it equals ``'bash'`` (treated as an interactive
        session rather than a batch job).
        """
        # The only way the lookup can fail is a missing key, so `.get` replaces
        # the former blanket `except Exception: pass`.
        job_name = os.environ.get('SLURM_JOB_NAME')
        if job_name is None or job_name == 'bash':
            return

        log.info('Set SLURM handle signals.')
        signal.signal(signal.SIGUSR1, self.sig_handler)
        signal.signal(signal.SIGTERM, self.term_handler)

    def sig_handler(self, signum, frame):  # pragma: no-cover
        """Handle ``SIGUSR1``: save a checkpoint, requeue the job, close the logger.

        Only acts when ``self.trainer.is_global_zero`` is truthy; otherwise it
        returns immediately.

        Args:
            signum: Signal number supplied by the ``signal`` machinery (unused).
            frame: Current stack frame supplied by the ``signal`` machinery (unused).

        Raises:
            KeyError: If ``SLURM_JOB_ID`` is not set in the environment.
        """
        if not self.trainer.is_global_zero:
            return

        # save weights
        log.info('handling SIGUSR1')
        self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)

        # find job id
        job_id = os.environ['SLURM_JOB_ID']
        cmd = ['scontrol', 'requeue', job_id]

        # requeue job
        log.info(f'requeing job {job_id}...')
        try:
            result = call(cmd)
        except FileNotFoundError:
            # This can occur if a subprocess call to `scontrol` is run outside a shell context.
            # Re-attempt the call with a shell context; any error raised here propagates to the user.
            # A shell command must be passed as a single string. Every element of `cmd`
            # is already a `str`, so it can be joined directly.
            result = call(' '.join(cmd), shell=True)

        # report the outcome of the requeue attempt
        if result == 0:
            log.info(f'requeued exp {job_id}')
        else:
            log.warning('requeue failed...')

        # close experiment to avoid issues
        self.trainer.logger.close()

    def term_handler(self, signum, frame):  # pragma: no-cover
        """Handle ``SIGTERM`` by logging a message and otherwise ignoring it."""
        log.info("bypassing sigterm")
|