added auto port find
This commit is contained in:
parent
5439dc0844
commit
34ddb0ec98
|
@ -150,7 +150,6 @@ class Trainer(TrainerIO):
|
|||
self.node_rank = 0
|
||||
self.use_ddp = False
|
||||
self.use_dp = False
|
||||
self._ddp_port = None
|
||||
|
||||
# training bookkeeping
|
||||
self.total_batch_nb = 0
|
||||
|
@ -398,7 +397,15 @@ class Trainer(TrainerIO):
|
|||
# -----------------------------
|
||||
# MODEL TRAINING
|
||||
# -----------------------------
|
||||
def __kill_ddp_ports(self, port_nb):
|
||||
def __find_open_port(self, port=None):
|
||||
|
||||
if port is None:
|
||||
try:
|
||||
port = os.environ['MASTER_PORT']
|
||||
except Exception as e:
|
||||
port = 12910
|
||||
os.environ['MASTER_PORT'] = f'{port}'
|
||||
|
||||
def get_pids(port):
|
||||
command = "sudo lsof -i :%s | awk '{print $2}'" % port
|
||||
pids = subprocess.check_output(command, shell=True)
|
||||
|
@ -411,19 +418,21 @@ class Trainer(TrainerIO):
|
|||
except:
|
||||
pass
|
||||
|
||||
# kill all processes on this port
|
||||
pids = set(get_pids(port_nb))
|
||||
command = 'sudo kill -9 {}'.format(' '.join([str(pid) for pid in pids]))
|
||||
os.system(command)
|
||||
# get the pids of processes using this port
|
||||
pids = set(get_pids(port))
|
||||
|
||||
# if no processes on this port, then we're good
|
||||
if len(pids) == 0:
|
||||
return
|
||||
|
||||
# port is already in use: try the next port number and check again
|
||||
port = int(port) + 1
|
||||
self.__find_open_port(str(port))
|
||||
|
||||
def fit(self, model):
|
||||
|
||||
# when using multi-node or DDP within a node start each module in a separate process
|
||||
if self.use_ddp:
|
||||
# clear any processes running on the ddp port
|
||||
self.__kill_ddp_ports(self.ddp_port)
|
||||
|
||||
# must copy only the meta of the exp so it survives pickle/unpickle when going to new process
|
||||
self.experiment = self.experiment.get_meta_copy()
|
||||
|
||||
|
@ -561,19 +570,6 @@ class Trainer(TrainerIO):
|
|||
# continue training routine
|
||||
self.__run_pretrain_routine(model)
|
||||
|
||||
@property
|
||||
def ddp_port(self):
|
||||
if self._ddp_port is None:
|
||||
try:
|
||||
port = os.environ['MASTER_PORT']
|
||||
except Exception as e:
|
||||
port = 12910
|
||||
os.environ['MASTER_PORT'] = f'{port}'
|
||||
|
||||
self._ddp_port = port
|
||||
|
||||
return self._ddp_port
|
||||
|
||||
def __init_tcp_connection(self):
|
||||
"""
|
||||
Connect all procs in the world using the env:// init
|
||||
|
@ -583,13 +579,12 @@ class Trainer(TrainerIO):
|
|||
:return:
|
||||
"""
|
||||
# sets the appropriate port
|
||||
_ = self.ddp_port
|
||||
self.__find_open_port()
|
||||
|
||||
root_node = self.__resolve_root_node_address()
|
||||
os.environ['MASTER_ADDR'] = root_node
|
||||
dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)
|
||||
|
||||
|
||||
def __resolve_root_node_address(self):
|
||||
try:
|
||||
root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
|
||||
|
|
Loading…
Reference in New Issue