diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 1606d99a5a..58f6b71dcc 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -383,8 +383,8 @@ class Trainer(TrainerIO): root_node = os.environ['SLURM_NODELIST'].split(' ')[0] os.environ['MASTER_ADDR'] = root_node os.environ['MASTER_PORT'] = f'{port}' - # dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size) - dist.init_process_group("nccl") + dist.init_process_group("nccl", rank=self.proc_rank) + # dist.init_process_group("nccl") def __run_pretrain_routine(self, model): """