diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index a1d63d2c6b..1bdd165496 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -374,9 +374,23 @@ class Trainer(TrainerIO):
         if self.use_ddp:
             # must copy only the meta of the exp so it survives pickle/unpickle when going to new process
             self.experiment = self.experiment.get_meta_copy()
-            task = int(os.environ['SLURM_LOCALID'])
-            self.ddp_train(task, model)
-            # mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
+
+            # whenever we have the correct number of tasks, we let slurm manage processes
+            # otherwise we launch the required number of processes
+            nb_slurm_tasks = int(os.environ['SLURM_NTASKS'])
+            nb_requested_gpus = len(self.data_parallel_device_ids)
+            is_slurm_managing_tasks = nb_slurm_tasks == nb_requested_gpus
+            if is_slurm_managing_tasks:
+                task = int(os.environ['SLURM_LOCALID'])
+                self.ddp_train(task, model)
+            else:
+                msg = f"""
+                You requested {nb_requested_gpus} GPUs but launched {nb_slurm_tasks} slurm tasks.
+                We will launch {nb_requested_gpus} processes for you.
+                We recommend you let slurm manage the processes by setting: --ntasks-per-node={nb_requested_gpus}
+                """
+                warnings.warn(msg)
+                mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
 
         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
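
For context, the decision this patch encodes can be read as a standalone helper. Below is a minimal sketch, not the patch itself: the names `launch_ddp` and `train_fn` are hypothetical stand-ins for the patch's `self.ddp_train`, and `requested_gpu_ids` stands in for `self.data_parallel_device_ids`. The idea follows the diff: when `SLURM_NTASKS` equals the number of requested GPUs, SLURM has already started one process per GPU and `SLURM_LOCALID` gives the local rank; otherwise the code falls back to spawning the workers itself via `torch.multiprocessing.spawn`.

```python
import os
import warnings

import torch.multiprocessing as mp


def launch_ddp(train_fn, model, requested_gpu_ids):
    """Run DDP training, letting SLURM manage processes when it can.

    Hypothetical helper mirroring the logic in the patch above.
    """
    # SLURM_NTASKS is set inside a SLURM allocation; default to 1 otherwise.
    nb_slurm_tasks = int(os.environ.get('SLURM_NTASKS', '1'))
    nb_requested_gpus = len(requested_gpu_ids)

    if nb_slurm_tasks == nb_requested_gpus:
        # SLURM launched one task per GPU: SLURM_LOCALID is this task's
        # index on the node, which serves as the local rank.
        local_rank = int(os.environ['SLURM_LOCALID'])
        train_fn(local_rank, model)
    else:
        warnings.warn(
            f'Requested {nb_requested_gpus} GPUs but SLURM launched '
            f'{nb_slurm_tasks} tasks; spawning {nb_requested_gpus} '
            f'processes instead. Prefer --ntasks-per-node={nb_requested_gpus}.'
        )
        # mp.spawn invokes train_fn(process_index, model) once per process,
        # passing the process index as the first argument.
        mp.spawn(train_fn, nprocs=nb_requested_gpus, args=(model,))
```

As the warning in the patch recommends, the SLURM-managed path is reached by submitting the job with `--ntasks-per-node` equal to the number of GPUs requested per node, so that task count and GPU count line up.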