From 34ddb0ec98b847dee8438fce73c87175005e4011 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 14:45:47 -0400 Subject: [PATCH] added auto port find --- pytorch_lightning/models/trainer.py | 43 +++++++++++++---------------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 1687093116..55fd27ddce 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -150,7 +150,6 @@ class Trainer(TrainerIO): self.node_rank = 0 self.use_ddp = False self.use_dp = False - self._ddp_port = None # training bookeeping self.total_batch_nb = 0 @@ -398,7 +397,15 @@ class Trainer(TrainerIO): # ----------------------------- # MODEL TRAINING # ----------------------------- - def __kill_ddp_ports(self, port_nb): + def __find_open_port(self, port=None): + + if port is None: + try: + port = os.environ['MASTER_PORT'] + except Exception as e: + port = 12910 + os.environ['MASTER_PORT'] = f'{port}' + def get_pids(port): command = "sudo lsof -i :%s | awk '{print $2}'" % port pids = subprocess.check_output(command, shell=True) @@ -411,19 +418,21 @@ class Trainer(TrainerIO): except: pass - # kill all processes on this port - pids = set(get_pids(port_nb)) - command = 'sudo kill -9 {}'.format(' '.join([str(pid) for pid in pids])) - os.system(command) + # get pids in this port + pids = set(get_pids(port)) + # if no processes on this port, then we're good + if len(pids) == 0: + return + + # port wasn't open. Pick a new port and keep trying + port = int(port) + 1 + self.__find_open_port(str(port)) def fit(self, model): # when using multi-node or DDP within a node start each module in a separate process if self.use_ddp: - # clear any processes running on the ddp port - self.__kill_ddp_ports(self.ddp_port) - # must copy only the meta of the exp so it survives pickle/unpickle when going to new process self.experiment = self.experiment.get_meta_copy() @@ -561,19 +570,6 @@ class Trainer(TrainerIO): # continue training routine self.__run_pretrain_routine(model) - @property - def ddp_port(self): - if self._ddp_port is None: - try: - port = os.environ['MASTER_PORT'] - except Exception as e: - port = 12910 - os.environ['MASTER_PORT'] = f'{port}' - - self._ddp_port = port - - return self._ddp_port - def __init_tcp_connection(self): """ Connect all procs in the world using the env:// init @@ -583,13 +579,12 @@ class Trainer(TrainerIO): :return: """ # sets the appropriate port - _ = self.ddp_port + self.__find_open_port() root_node = self.__resolve_root_node_address() os.environ['MASTER_ADDR'] = root_node dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size) - def __resolve_root_node_address(self): try: root_node = os.environ['SLURM_NODELIST'].split(' ')[0]