diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index f04b1fda61..01f105151a 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -150,7 +150,7 @@ class Trainer(TrainerIO):
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
-        self.default_ddp_port = 12910
+        self._ddp_port = None
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -399,11 +399,29 @@ class Trainer(TrainerIO):
 
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
+    def __kill_ddp_ports(self, port_nb):
+        def get_pids(port):
+            command = "sudo lsof -i :%s | awk '{print $2}'" % port
+            pids = subprocess.check_output(command, shell=True).decode('utf-8')
+            pids = pids.strip()
+            if pids:
+                pids = re.sub(' +', ' ', pids)
+                for pid in pids.split('\n'):
+                    try:
+                        yield int(pid)
+                    except ValueError:
+                        pass
+
+        # kill all processes on this port
+        pids = set(get_pids(port_nb))
+        command = 'sudo kill -9 {}'.format(' '.join([str(pid) for pid in pids]))
+        os.system(command)
+
     def fit(self, model):
         # when using multi-node or DDP within a node start each module in a separate process
         if self.use_ddp:
-            self.__kill_ddp_ports(self.default_ddp_port)
+            self.__kill_ddp_ports(self.ddp_port)
 
             # must copy only the meta of the exp so it survives pickle/unpickle when going to new process
             self.experiment = self.experiment.get_meta_copy()
@@ -542,6 +560,19 @@ class Trainer(TrainerIO):
         # continue training routine
         self.__run_pretrain_routine(model)
 
+    @property
+    def ddp_port(self):
+        if self._ddp_port is None:
+            try:
+                port = os.environ['MASTER_PORT']
+            except KeyError:
+                port = 12910  # previous default_ddp_port
+                os.environ['MASTER_PORT'] = f'{port}'
+
+            self._ddp_port = port
+
+        return self._ddp_port
+
     def __init_tcp_connection(self):
         """
         Connect all procs in the world using the env:// init
@@ -550,17 +581,11 @@ class Trainer(TrainerIO):
         :param tries:
         :return:
         """
-        try:
-            port = os.environ['MASTER_PORT']
-        except Exception as e:
-            port = self.default_ddp_port
-            os.environ['MASTER_PORT'] = f'{port}'
+        # resolve the DDP port (sets MASTER_PORT as a side effect)
+        _ = self.ddp_port
 
         root_node = self.__resolve_root_node_address()
         os.environ['MASTER_ADDR'] = root_node
-
-        self.default_ddp_port = port
-
         dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)
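
For reference, a minimal, self-contained sketch of the lazy port-resolution pattern the new ddp_port property follows: reuse MASTER_PORT when it is already set, otherwise fall back to a default and export it so spawned workers agree on the port. PortResolver is an illustrative stand-in, not code from this patch; it only assumes the same 12910 default the Trainer previously shipped with.

    import os


    class PortResolver:
        """Illustrative stand-in for the Trainer's ddp_port property."""

        DEFAULT_PORT = 12910  # same default the Trainer used before this change

        def __init__(self):
            self._ddp_port = None

        @property
        def ddp_port(self):
            # resolve once, then cache; export the choice for child processes
            if self._ddp_port is None:
                port = os.environ.get('MASTER_PORT')
                if port is None:
                    port = self.DEFAULT_PORT
                    os.environ['MASTER_PORT'] = f'{port}'
                self._ddp_port = port
            return self._ddp_port


    if __name__ == '__main__':
        resolver = PortResolver()
        print(resolver.ddp_port)          # first access resolves the port
        print(os.environ['MASTER_PORT'])  # spawned processes will see the same value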