added auto port find

This commit is contained in:
William Falcon 2019-07-24 14:45:47 -04:00
parent 5439dc0844
commit 34ddb0ec98
1 changed files with 19 additions and 24 deletions

View File

@ -150,7 +150,6 @@ class Trainer(TrainerIO):
self.node_rank = 0
self.use_ddp = False
self.use_dp = False
self._ddp_port = None
# training bookkeeping
self.total_batch_nb = 0
@ -398,7 +397,15 @@ class Trainer(TrainerIO):
# -----------------------------
# MODEL TRAINING
# -----------------------------
def __kill_ddp_ports(self, port_nb):
def __find_open_port(self, port=None):
if port is None:
try:
port = os.environ['MASTER_PORT']
except Exception as e:
port = 12910
os.environ['MASTER_PORT'] = f'{port}'
def get_pids(port):
command = "sudo lsof -i :%s | awk '{print $2}'" % port
pids = subprocess.check_output(command, shell=True)
@ -411,19 +418,21 @@ class Trainer(TrainerIO):
except:
pass
# kill all processes on this port
pids = set(get_pids(port_nb))
command = 'sudo kill -9 {}'.format(' '.join([str(pid) for pid in pids]))
os.system(command)
# get pids on this port
pids = set(get_pids(port))
# if no processes on this port, then we're good
if len(pids) == 0:
return
# port wasn't open. Pick a new port and keep trying
port = int(port) + 1
self.__find_open_port(str(port))
def fit(self, model):
# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp:
# clear any processes running on the ddp port
self.__kill_ddp_ports(self.ddp_port)
# must copy only the meta of the exp so it survives pickle/unpickle when going to new process
self.experiment = self.experiment.get_meta_copy()
@ -561,19 +570,6 @@ class Trainer(TrainerIO):
# continue training routine
self.__run_pretrain_routine(model)
@property
def ddp_port(self):
    """
    Port used for DDP communication, resolved on first access and cached
    in self._ddp_port afterwards.

    The value is read from the MASTER_PORT environment variable. When that
    variable is not set, the default 12910 is used and written back to
    MASTER_PORT as a string.

    :return: the MASTER_PORT string when the variable was set,
        otherwise the int 12910
    """
    if self._ddp_port is None:
        try:
            port = os.environ['MASTER_PORT']
        except KeyError:
            # MASTER_PORT is not exported yet: fall back to the default and
            # publish it through the environment
            port = 12910
            os.environ['MASTER_PORT'] = f'{port}'

        self._ddp_port = port

    return self._ddp_port
def __init_tcp_connection(self):
"""
Connect all procs in the world using the env:// init
@ -583,13 +579,12 @@ class Trainer(TrainerIO):
:return:
"""
# sets the appropriate port
_ = self.ddp_port
self.__find_open_port()
root_node = self.__resolve_root_node_address()
os.environ['MASTER_ADDR'] = root_node
dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)
def __resolve_root_node_address(self):
try:
root_node = os.environ['SLURM_NODELIST'].split(' ')[0]