auto port kill before starting ddp

This commit is contained in:
William Falcon 2019-07-24 14:36:29 -04:00
parent 1ae91aac32
commit 8f06118154
1 changed files with 35 additions and 10 deletions

View File

@ -150,7 +150,7 @@ class Trainer(TrainerIO):
self.node_rank = 0
self.use_ddp = False
self.use_dp = False
self.default_ddp_port = 12910
self._ddp_port = None
# training bookeeping
self.total_batch_nb = 0
@ -399,11 +399,29 @@ class Trainer(TrainerIO):
# MODEL TRAINING
# -----------------------------
def __kill_ddp_ports(self, port_nb):
def get_pids(port):
command = "sudo lsof -i :%s | awk '{print $2}'" % port
pids = subprocess.check_output(command, shell=True)
pids = pids.strip()
if pids:
pids = re.sub(' +', ' ', pids)
for pid in pids.split('\n'):
try:
yield int(pid)
except:
pass
# kill all processes on this port
pids = set(get_pids(port_nb))
command = 'sudo kill -9 {}'.format(' '.join([str(pid) for pid in pids]))
os.system(command)
def fit(self, model):
# when using multi-node or DDP within a node start each module in a separate process
if self.use_ddp:
self.__kill_ddp_ports(self.default_ddp_port)
self.__kill_ddp_ports(self.ddp_port)
# must copy only the meta of the exp so it survives pickle/unpickle when going to new process
self.experiment = self.experiment.get_meta_copy()
@ -542,6 +560,19 @@ class Trainer(TrainerIO):
# continue training routine
self.__run_pretrain_routine(model)
@property
def ddp_port(self):
if self._ddp_port is None:
try:
port = os.environ['MASTER_PORT']
except Exception as e:
port = self.default_ddp_port
os.environ['MASTER_PORT'] = f'{port}'
self._ddp_port = port
return self._ddp_port
def __init_tcp_connection(self):
"""
Connect all procs in the world using the env:// init
@ -550,17 +581,11 @@ class Trainer(TrainerIO):
:param tries:
:return:
"""
try:
port = os.environ['MASTER_PORT']
except Exception as e:
port = self.default_ddp_port
os.environ['MASTER_PORT'] = f'{port}'
# sets the appropriate port
_ = self.ddp_port
root_node = self.__resolve_root_node_address()
os.environ['MASTER_ADDR'] = root_node
self.default_ddp_port = port
dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)