added slurm managed flag catch for non-slurm peeps

William Falcon 2019-07-20 08:38:17 -04:00
parent 9757841e67
commit a514674358
1 changed file with 13 additions and 7 deletions


@@ -146,6 +146,7 @@ class Trainer(TrainerIO):
         self.print_nan_grads = print_nan_grads
         self.data_parallel_device_ids = None
         self.world_size = 1
+        self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
@@ -449,9 +450,9 @@ class Trainer(TrainerIO):
         # node rank using relative slurm id
         # otherwise default to node rank 0
         try:
-            node_rank = int(os.environ['SLURM_NODEID'])
-        except KeyError as e:
-            node_rank = 0
+            self.node_rank = int(os.environ['SLURM_NODEID'])
+        except Exception as e:
+            self.node_rank = 0

         # recover original exp before went into process
         # init in write mode only on proc 0
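
The change above is the usual environment-variable fallback: SLURM exports SLURM_NODEID (the node's 0-based index within the allocation) only inside a SLURM job, so outside SLURM the lookup fails and the node rank should default to 0. Widening KeyError to Exception also covers a value that fails the int() conversion. A minimal standalone sketch of the same pattern (detect_node_rank is an illustrative name, not part of the Trainer API, and the sketch narrows the catch to the two errors that can actually occur here):

    import os

    def detect_node_rank() -> int:
        # SLURM sets SLURM_NODEID only inside an allocation; elsewhere the
        # lookup raises KeyError, and a malformed value raises ValueError
        # on int(). Either way, fall back to single-node rank 0.
        try:
            return int(os.environ['SLURM_NODEID'])
        except (KeyError, ValueError):
            return 0
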
@@ -459,10 +460,10 @@ class Trainer(TrainerIO):
             self.experiment = self.experiment.get_non_ddp_exp()

         # show progbar only on prog_rank 0
-        self.prog_bar = self.prog_bar and node_rank == 0 and gpu_nb == 0
+        self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0

         # determine which process we are and world size
-        self.proc_rank = node_rank * len(self.data_parallel_device_ids) + gpu_nb
+        self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)

         # set up server using proc 0's ip address
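
The rank arithmetic above flattens a (node, local GPU) pair into a unique global rank: with one process per GPU, proc_rank = node_rank * gpus_per_node + gpu_nb, and world_size = nb_gpu_nodes * gpus_per_node. A worked example with hypothetical job dimensions:

    # hypothetical job: 4 nodes, 8 GPUs per node
    nb_gpu_nodes = 4
    gpus_per_node = 8         # i.e. len(data_parallel_device_ids)
    node_rank, gpu_nb = 2, 3  # this process: node 2, local GPU 3

    proc_rank = node_rank * gpus_per_node + gpu_nb  # 2*8 + 3 = 19
    world_size = nb_gpu_nodes * gpus_per_node       # 4*8 = 32
    assert 0 <= proc_rank < world_size              # ranks 0..31, all distinct
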
@@ -470,6 +471,10 @@ class Trainer(TrainerIO):
         # where to store ip_table
         self.__init_tcp_connection()

+        print('-'*100)
+        print(f'INIT COMPLETE')
+        print('-'*100)
+
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
@@ -513,8 +518,9 @@ class Trainer(TrainerIO):
         root_node = '127.0.0.2'
         os.environ['MASTER_ADDR'] = root_node

         sleep(self.proc_rank*0.5)
+        print('-'*100)
+        print(f'INIT RANK: {self.proc_rank}, NODE:{self.node_rank}, WORLD_SIZE:{self.world_size}, ADDR: {root_node}, PORT: {port}')
+        print('-'*100)
         dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)

     def __run_pretrain_routine(self, model):
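
For context, dist.init_process_group with PyTorch's default env:// rendezvous requires every process to see the same MASTER_ADDR/MASTER_PORT and to pass a unique rank plus the shared world_size, which is what the code above arranges (the staggered sleep spaces out the processes' connection attempts). A self-contained sketch of that handshake, using a placeholder address and port rather than the ones __init_tcp_connection derives:

    import os
    import torch.distributed as dist

    def init_ddp(proc_rank: int, world_size: int) -> None:
        # all processes must agree on the rendezvous endpoint;
        # these values are placeholders for illustration only
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '12910'

        # NCCL backend: one process per GPU; rank must be unique
        # across the job and lie in [0, world_size)
        dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)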