added debugging util
parent 0009aa2bcd
commit f41fdc1ad8
@@ -425,7 +425,9 @@ class Trainer(TrainerIO):
                 If you're not using SLURM, ignore this message!
                 """
                 warnings.warn(msg)
-                mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, ))
+                d = {}
+                mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, d))
+                print(d)
 
         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
@@ -472,7 +474,7 @@ class Trainer(TrainerIO):
 
             self.__run_pretrain_routine(model)
 
-    def ddp_train(self, gpu_nb, model):
+    def ddp_train(self, gpu_nb, model, d):
         """
         Entry point into a DP thread
         :param gpu_nb:
@@ -482,6 +484,8 @@ class Trainer(TrainerIO):
         """
         # node rank using relative slurm id
         # otherwise default to node rank 0
+        d['helloooo'] = 12.0
+
         try:
             node_id = os.environ['SLURM_NODEID']
             self.node_rank = int(node_id)
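
The new d = {} / print(d) pair appears to probe whether a plain dict passed through mp.spawn carries values back from the spawned workers. It will not: torch.multiprocessing.spawn starts separate processes, so args are pickled into each child and d['helloooo'] = 12.0 only mutates the child's copy; the parent's print(d) still prints {}. Below is a minimal standalone sketch of that behaviour (not part of this commit; the names worker, plain and shared are hypothetical), using a multiprocessing.Manager dict as one way to actually get values back:

import torch.multiprocessing as mp
from multiprocessing import Manager


def worker(rank, plain, shared):
    # each spawned child receives a pickled copy of `plain`; this write stays local
    plain['helloooo'] = 12.0
    # the Manager proxy forwards writes back to the parent process
    shared[rank] = 12.0


if __name__ == '__main__':
    plain = {}
    with Manager() as manager:
        shared = manager.dict()
        mp.spawn(worker, nprocs=2, args=(plain, shared))
        print(plain)         # {} -- child mutations are not visible here
        print(dict(shared))  # e.g. {0: 12.0, 1: 12.0}

If the goal is to hand results from ddp_train back to the Trainer, a Manager().dict() or a torch.multiprocessing.Queue would survive the spawn round trip where a plain dict does not.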