diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ececc7b317..e88750ad25 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -992,7 +992,7 @@ class Trainer(TrainerIO):
             # ---------------
             # RUN TRAIN STEP
             # ---------------
-            batch_result = self.__run_training_batch(batch, batch_nb)
+            batch_result, grad_norm_dic = self.__run_training_batch(batch, batch_nb)
             early_stop_epoch = batch_result == -1
 
             # ---------------
@@ -1023,10 +1023,7 @@ class Trainer(TrainerIO):
                     metrics.update(mem_map)
 
                 # add norms
-                if self.track_grad_norm > 0:
-                    model = self.__get_model()
-                    grad_norm_dic = model.grad_norm(self.track_grad_norm)
-                    metrics.update(grad_norm_dic)
+                metrics.update(grad_norm_dic)
 
                 if self.__is_function_implemented('on_training_metrics'):
                     model.on_training_metrics(metrics)
@@ -1178,8 +1175,11 @@ class Trainer(TrainerIO):
                 print(param, param.grad)
 
     def __run_training_batch(self, batch, batch_nb):
+        # track grad norms
+        grad_norm_dic = {}
+
         if batch is None:
-            return 0
+            return 0, grad_norm_dic
 
         # hook
         if self.__is_function_implemented('on_batch_start'):
@@ -1187,7 +1187,7 @@ class Trainer(TrainerIO):
             response = model_ref.on_batch_start(batch)
 
             if response == -1:
-                return -1
+                return -1, grad_norm_dic
 
         if self.show_progress_bar:
             self.progress_bar.update(1)
@@ -1226,6 +1226,13 @@ class Trainer(TrainerIO):
 
         # gradient update with accumulated gradients
        if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
+
+            # track gradient norms when requested
+            if batch_nb % self.row_log_interval == 0:
+                if self.track_grad_norm > 0:
+                    model = self.__get_model()
+                    grad_norm_dic = model.grad_norm(self.track_grad_norm)
+
             # clip gradients
             self.__clip_gradients()
 
@@ -1250,7 +1257,7 @@ class Trainer(TrainerIO):
             model = self.__get_model()
             model.on_batch_end()
 
-        return 0
+        return 0, grad_norm_dic
 
     def __run_evaluation(self, test=False):
         # when testing make sure user defined a test step
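
The change above only assumes that `model.grad_norm(self.track_grad_norm)` returns a flat dict that `metrics.update(grad_norm_dic)` can consume. For reference, a minimal sketch of such a helper; the standalone-function form, key names, and rounding are illustrative assumptions, not the actual LightningModule implementation:

import torch


def grad_norm(module, norm_type):
    """Collect per-parameter gradient norms plus a total norm into a flat dict.

    Illustrative sketch only: the key format and rounding are assumptions; the
    trainer just needs something it can pass to metrics.update(grad_norm_dic).
    """
    norms = {}
    all_norms = []
    for name, p in module.named_parameters():
        if p.grad is None:
            continue
        # norm of this parameter's gradient
        param_norm = p.grad.data.norm(norm_type)
        norms['grad_{}_norm_{}'.format(norm_type, name)] = round(param_norm.item(), 3)
        all_norms.append(param_norm)

    # combined norm across all gradients
    if all_norms:
        total_norm = torch.stack(all_norms).norm(norm_type)
        norms['grad_{}_norm_total'.format(norm_type)] = round(total_norm.item(), 3)
    return norms

With this diff the norms are computed only on optimizer-step batches that also hit the logging interval, so a trainer configured with, say, `Trainer(track_grad_norm=2, row_log_interval=10, accumulate_grad_batches=1)` (assuming those attributes map to constructor arguments) logs 2-norms every tenth batch instead of on every metrics update.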