diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index fa219ba14e..83ea60daab 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -343,11 +343,20 @@ class Trainer(TrainerIO):
         # this is based on its global rank
         # it communicates its ip by saving an ip_table to the slurm cluster logging dir
         # every other process waits for this ip to appear before continuing
-        ip_table_name = f'.ip_meta_' + os.environ['SLURM_JOB_ID']
+        ip_table_name = 'ip_meta_' + os.environ['SLURM_JOB_ID']
         ip_file = os.path.join(ip_file_dir, ip_table_name)
         os.makedirs(ip_file_dir, exist_ok=True)
 
-        dist.init_process_group("nccl", init_method=f'file://{ip_file}', rank=self.proc_rank,
+        # Rendezvous on the first node of the SLURM allocation. 'env://' does
+        # not accept a host in the URL (it reads MASTER_ADDR/MASTER_PORT from
+        # the environment), so an explicit host needs the tcp:// init method.
+        # NOTE(review): SLURM_NODELIST is comma-separated and may contain
+        # compressed ranges (e.g. 'node[1-4]'); expanding via
+        # 'scontrol show hostnames' is more robust -- confirm cluster format.
+        root_node = os.environ['SLURM_NODELIST'].split(',')[0]
+        port = os.environ.get('MASTER_PORT', '12910')
+        print(f'SLURM ROOT NODE: {root_node}:{port}')
+
+        dist.init_process_group("nccl", init_method=f'tcp://{root_node}:{port}',
+                                rank=self.proc_rank,
                                 world_size=self.world_size)
 
         # self.__init_tcp_connection(ip_file_dir)