updated args

This commit is contained in:
William Falcon 2019-06-26 17:54:59 -04:00
parent 4a3c9de857
commit 12a0e98920
1 changed files with 2 additions and 3 deletions

View File

@ -340,7 +340,7 @@ class Trainer(TrainerIO):
# count items in memory
# nb_params, nb_tensors = count_mem_items()
metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
# add gpu memory
if self.on_gpu:
@ -349,7 +349,7 @@ class Trainer(TrainerIO):
# add norms
if self.track_grad_norm > 0:
grad_norm_dic = self.model.grad_norm(self.track_grad_norm)
grad_norm_dic = self.model.module.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
# log metrics
@ -404,7 +404,6 @@ class Trainer(TrainerIO):
model_specific_tqdm_metrics_dic = output['tqdm_metrics']
loss = output['loss']
pdb.set_trace()
self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
# backward pass