diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ab8dbdb47b..d206f0931a 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -371,7 +371,7 @@ class Trainer(TrainerIO): # continue training routine self.__run_pretrain_routine(model) - def __init_tcp_connection(self, port=12945): + def __init_tcp_connection(self): """ Connect all procs in the world using the env:// init Use the first node as the root address @@ -379,6 +379,10 @@ class Trainer(TrainerIO): :param tries: :return: """ + try: + port = os.environ['MASTER_PORT'] + except Exception as e: + port = 12910 root_node = os.environ['SLURM_NODELIST'].split(' ')[0] os.environ['MASTER_ADDR'] = root_node