Added a catch around the SLURM-managed node-rank lookup so non-SLURM users fall back to node rank 0
parent 9757841e67
commit a514674358
```diff
@@ -146,6 +146,7 @@ class Trainer(TrainerIO):
         self.print_nan_grads = print_nan_grads
         self.data_parallel_device_ids = None
         self.world_size = 1
+        self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
```
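For context, here is a minimal sketch of how mode flags like `use_ddp`/`use_dp` are typically resolved once the node and GPU configuration is known. The helper `resolve_distributed_mode` is hypothetical and not part of this diff; only the attribute names come from the code above.

```python
# Hypothetical helper (not in this diff): resolve the DP/DDP flags
# initialized above from the node and device configuration.
def resolve_distributed_mode(nb_gpu_nodes, data_parallel_device_ids):
    # Multiple nodes require DistributedDataParallel.
    use_ddp = nb_gpu_nodes > 1
    # A single node with several GPUs can use plain DataParallel.
    use_dp = (not use_ddp
              and data_parallel_device_ids is not None
              and len(data_parallel_device_ids) > 1)
    return use_ddp, use_dp


print(resolve_distributed_mode(2, [0, 1]))  # (True, False)
print(resolve_distributed_mode(1, [0, 1]))  # (False, True)
```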
```diff
@@ -449,9 +450,9 @@ class Trainer(TrainerIO):
         # node rank using relative slurm id
         # otherwise default to node rank 0
         try:
-            node_rank = int(os.environ['SLURM_NODEID'])
-        except KeyError as e:
-            node_rank = 0
+            self.node_rank = int(os.environ['SLURM_NODEID'])
+        except Exception as e:
+            self.node_rank = 0

         # recover original exp before went into process
         # init in write mode only on proc 0
```
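Standalone, the fallback looks like this; a minimal sketch using only the environment variable the diff reads. Note that broadening `except KeyError` to `except Exception` also covers a malformed (non-integer) `SLURM_NODEID`, not just a missing one.

```python
import os

# SLURM sets SLURM_NODEID for each node in a job; outside SLURM the
# variable is absent, so fall back to node rank 0.
try:
    node_rank = int(os.environ['SLURM_NODEID'])
except Exception:
    node_rank = 0

print(node_rank)  # 0 on a non-SLURM machine
```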
```diff
@@ -459,10 +460,10 @@ class Trainer(TrainerIO):
             self.experiment = self.experiment.get_non_ddp_exp()

         # show progbar only on prog_rank 0
-        self.prog_bar = self.prog_bar and node_rank == 0 and gpu_nb == 0
+        self.prog_bar = self.prog_bar and self.node_rank == 0 and gpu_nb == 0

         # determine which process we are and world size
-        self.proc_rank = node_rank * len(self.data_parallel_device_ids) + gpu_nb
+        self.proc_rank = self.node_rank * len(self.data_parallel_device_ids) + gpu_nb
         self.world_size = self.nb_gpu_nodes * len(self.data_parallel_device_ids)

         # set up server using proc 0's ip address
```
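The rank arithmetic gives every GPU on every node a unique global rank. A worked example, assuming 2 nodes with 4 GPUs each (values chosen for illustration only):

```python
# Global rank layout for nb_gpu_nodes=2 and 4 GPUs per node:
# node 0 -> ranks 0..3, node 1 -> ranks 4..7, world_size = 8.
nb_gpu_nodes = 2
data_parallel_device_ids = [0, 1, 2, 3]

for node_rank in range(nb_gpu_nodes):
    for gpu_nb in range(len(data_parallel_device_ids)):
        proc_rank = node_rank * len(data_parallel_device_ids) + gpu_nb
        print(f'node {node_rank}, gpu {gpu_nb} -> proc_rank {proc_rank}')

world_size = nb_gpu_nodes * len(data_parallel_device_ids)
print(f'world_size = {world_size}')  # 8
```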
```diff
@@ -470,6 +471,10 @@ class Trainer(TrainerIO):
         # where to store ip_table
         self.__init_tcp_connection()

+        print('-'*100)
+        print(f'INIT COMPLETE')
+        print('-'*100)
+
         # CHOOSE OPTIMIZER
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
```
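The model class itself is not shown in this diff; as a sketch, a `configure_optimizers()` that satisfies the call above might look like the following (the model name, layer, and learning rate are assumptions):

```python
import torch
import torch.nn as nn

# Hypothetical model; only configure_optimizers() matters for the call above.
class ExampleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 2)

    def configure_optimizers(self):
        # Return the optimizer(s) the Trainer will step during training.
        return [torch.optim.Adam(self.parameters(), lr=1e-3)]


model = ExampleModel()
optimizers = model.configure_optimizers()
```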
```diff
@@ -513,8 +518,9 @@ class Trainer(TrainerIO):
         root_node = '127.0.0.2'

         os.environ['MASTER_ADDR'] = root_node

+        sleep(self.proc_rank*0.5)
         print('-'*100)
         print(f'INIT RANK: {self.proc_rank}, NODE:{self.node_rank}, WORLD_SIZE:{self.world_size}, ADDR: {root_node}, PORT: {port}')
         print('-'*100)
         dist.init_process_group("nccl", rank=self.proc_rank, world_size=self.world_size)

     def __run_pretrain_routine(self, model):
```
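Putting the rendezvous together: a minimal, single-process sketch of what `__init_tcp_connection` plus the call above amount to. The port and the rank/world-size values here are placeholders, and the staggered `sleep` mirrors the diff's attempt to keep all ranks from connecting at once.

```python
import os
from time import sleep

import torch.distributed as dist

# Rendezvous settings: every process must agree on proc 0's address/port.
os.environ['MASTER_ADDR'] = '127.0.0.2'  # root node IP from the diff
os.environ['MASTER_PORT'] = '12910'      # placeholder port

proc_rank, world_size = 0, 1             # single-process example values
sleep(proc_rank * 0.5)                   # stagger connection attempts

# The default env:// init method reads MASTER_ADDR/MASTER_PORT. NCCL
# requires GPUs; swap in 'gloo' to try this on a CPU-only machine.
dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
```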