added debugging util

This commit is contained in:
William Falcon 2019-07-24 10:47:49 -04:00
parent 0009aa2bcd
commit f41fdc1ad8
1 changed files with 6 additions and 2 deletions

View File

@ -425,7 +425,9 @@ class Trainer(TrainerIO):
If you're not using SLURM, ignore this message! If you're not using SLURM, ignore this message!
""" """
warnings.warn(msg) warnings.warn(msg)
mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, )) d = {}
mp.spawn(self.ddp_train, nprocs=len(self.data_parallel_device_ids), args=(model, d))
print(d)
# 1 gpu or dp option triggers training using DP module # 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues # easier to avoid NCCL issues
@ -472,7 +474,7 @@ class Trainer(TrainerIO):
self.__run_pretrain_routine(model) self.__run_pretrain_routine(model)
def ddp_train(self, gpu_nb, model): def ddp_train(self, gpu_nb, model, d):
""" """
Entry point into a DP thread Entry point into a DP thread
:param gpu_nb: :param gpu_nb:
@ -482,6 +484,8 @@ class Trainer(TrainerIO):
""" """
# node rank using relative slurm id # node rank using relative slurm id
# otherwise default to node rank 0 # otherwise default to node rank 0
d['helloooo'] = 12.0
try: try:
node_id = os.environ['SLURM_NODEID'] node_id = os.environ['SLURM_NODEID']
self.node_rank = int(node_id) self.node_rank = int(node_id)