updated args
This commit is contained in:
parent
4a3c9de857
commit
12a0e98920
|
@ -340,7 +340,7 @@ class Trainer(TrainerIO):
|
|||
# count items in memory
|
||||
# nb_params, nb_tensors = count_mem_items()
|
||||
|
||||
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
|
||||
|
||||
# add gpu memory
|
||||
if self.on_gpu:
|
||||
|
@ -349,7 +349,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
# add norms
|
||||
if self.track_grad_norm > 0:
|
||||
grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
|
||||
grad_norm_dic = self.model.module.grad_norm(self.track_grad_norm)
|
||||
metrics.update(grad_norm_dic)
|
||||
|
||||
# log metrics
|
||||
|
@ -404,7 +404,6 @@ class Trainer(TrainerIO):
|
|||
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
|
||||
loss = output['loss']
|
||||
|
||||
pdb.set_trace()
|
||||
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
|
||||
|
||||
# backward pass
|
||||
|
|
Loading…
Reference in New Issue