WIP: Moved grad_norm tracking code to __run_tng_batch (#278)

* Moved grad_norm tracking code to __run_tng_batch + added norms to tqdm_metrics

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py
This commit is contained in:
kvhooreb 2019-10-02 17:11:08 +02:00 committed by William Falcon
parent 614cb3c03b
commit 41236c7bbb
1 changed file with 15 additions and 8 deletions

View File

@@ -992,7 +992,7 @@ class Trainer(TrainerIO):
# ---------------
# RUN TRAIN STEP
# ---------------
batch_result = self.__run_training_batch(batch, batch_nb)
batch_result, grad_norm_dic = self.__run_training_batch(batch, batch_nb)
early_stop_epoch = batch_result == -1
# ---------------
@@ -1023,10 +1023,7 @@ class Trainer(TrainerIO):
metrics.update(mem_map)
# add norms
if self.track_grad_norm > 0:
model = self.__get_model()
grad_norm_dic = model.grad_norm(self.track_grad_norm)
metrics.update(grad_norm_dic)
metrics.update(grad_norm_dic)
if self.__is_function_implemented('on_training_metrics'):
model.on_training_metrics(metrics)
@@ -1178,8 +1175,11 @@ class Trainer(TrainerIO):
print(param, param.grad)
def __run_training_batch(self, batch, batch_nb):
# track grad norms
grad_norm_dic = {}
if batch is None:
return 0
return 0, grad_norm_dic
# hook
if self.__is_function_implemented('on_batch_start'):
@@ -1187,7 +1187,7 @@ class Trainer(TrainerIO):
response = model_ref.on_batch_start(batch)
if response == -1:
return -1
return -1, grad_norm_dic
if self.show_progress_bar:
self.progress_bar.update(1)
@@ -1226,6 +1226,13 @@ class Trainer(TrainerIO):
# gradient update with accumulated gradients
if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
# track gradient norms when requested
if batch_nb % self.row_log_interval == 0:
if self.track_grad_norm > 0:
model = self.__get_model()
grad_norm_dic = model.grad_norm(self.track_grad_norm)
# clip gradients
self.__clip_gradients()
@@ -1250,7 +1257,7 @@ class Trainer(TrainerIO):
model = self.__get_model()
model.on_batch_end()
return 0
return 0, grad_norm_dic
def __run_evaluation(self, test=False):
# when testing make sure user defined a test step