auto port kill before starting ddp
This commit is contained in:
parent
1ae91aac32
commit
8f06118154
|
@ -150,7 +150,7 @@ class Trainer(TrainerIO):
|
|||
self.node_rank = 0
|
||||
self.use_ddp = False
|
||||
self.use_dp = False
|
||||
self.default_ddp_port = 12910
|
||||
self._ddp_port = None
|
||||
|
||||
# training bookkeeping
|
||||
self.total_batch_nb = 0
|
||||
|
@ -399,11 +399,29 @@ class Trainer(TrainerIO):
|
|||
# MODEL TRAINING
|
||||
# -----------------------------
|
||||
def __kill_ddp_ports(self, port_nb):
    """
    Force-kill every process that ``lsof`` reports on the given port.

    Runs ``sudo lsof -i :<port>`` through the shell, collects the PID
    column of its output and, when at least one PID was found, issues a
    single ``sudo kill -9`` for all of them.

    :param port_nb: port number (int or str) to free up
    :return: None
    """
    def get_pids(port):
        # second column of the lsof listing is the PID
        command = "sudo lsof -i :%s | awk '{print $2}'" % port
        raw = subprocess.check_output(command, shell=True)

        # check_output returns bytes on Python 3; decode before any text handling
        if isinstance(raw, bytes):
            raw = raw.decode()

        for token in raw.split():
            try:
                yield int(token)
            except ValueError:
                # non-numeric entries (e.g. the "PID" column header) are not pids
                continue

    # kill all processes on this port
    pids = set(get_pids(port_nb))
    if not pids:
        # nothing is holding the port, so there is nothing to kill
        return

    command = 'sudo kill -9 {}'.format(' '.join(str(pid) for pid in sorted(pids)))
    os.system(command)
|
||||
|
||||
|
||||
def fit(self, model):
|
||||
|
||||
# when using multi-node or DDP within a node start each module in a separate process
|
||||
if self.use_ddp:
|
||||
self.__kill_ddp_ports(self.default_ddp_port)
|
||||
self.__kill_ddp_ports(self.ddp_port)
|
||||
|
||||
# must copy only the meta of the exp so it survives pickle/unpickle when going to new process
|
||||
self.experiment = self.experiment.get_meta_copy()
|
||||
|
@ -542,6 +560,19 @@ class Trainer(TrainerIO):
|
|||
# continue training routine
|
||||
self.__run_pretrain_routine(model)
|
||||
|
||||
@property
def ddp_port(self):
    """
    Port used for the DDP connection, resolved on first access and cached.

    On the first access ``MASTER_PORT`` is read from the environment. When
    it is not set, ``self.default_ddp_port`` is used instead and exported
    as ``MASTER_PORT``. The result is stored in ``self._ddp_port`` and
    returned unchanged on every later access.

    :return: the value of ``MASTER_PORT`` (a str) when the variable was set,
        otherwise ``self.default_ddp_port`` as-is
    """
    if self._ddp_port is None:
        try:
            port = os.environ['MASTER_PORT']
        except KeyError:
            # a missing key is the only expected failure of the lookup
            port = self.default_ddp_port
            os.environ['MASTER_PORT'] = f'{port}'

        self._ddp_port = port

    return self._ddp_port
|
||||
|
||||
def __init_tcp_connection(self):
|
||||
"""
|
||||
Connect all procs in the world using the env:// init
|
||||
|
@ -550,17 +581,11 @@ class Trainer(TrainerIO):
|
|||
:param tries:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
port = os.environ['MASTER_PORT']
|
||||
except Exception as e:
|
||||
port = self.default_ddp_port
|
||||
os.environ['MASTER_PORT'] = f'{port}'
|
||||
# sets the appropriate port
|
||||
_ = self.ddp_port
|
||||
|
||||
root_node = self.__resolve_root_node_address()
|
||||
os.environ['MASTER_ADDR'] = root_node
|
||||
|
||||
self.default_ddp_port = port
|
||||
|
||||
dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue