diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c013db5b48..ddab8f837b 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -246,7 +246,7 @@ class TrainLoop:
         opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop))
         return [(opt_idx, self.trainer.optimizers[opt_idx])]
 
-    def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
+    def on_after_backward(self, batch_idx, untouched_loss):
        # insert after step hook
        self.trainer.call_hook("on_after_backward")
 
@@ -760,7 +760,7 @@ class TrainLoop:
                    # hook - call this hook only
                    # when gradients have finished to accumulate
                    if not self.should_accumulate():
-                        self.on_after_backward(result.training_step_output, batch_idx, result.loss)
+                        self.on_after_backward(batch_idx, result.loss)
 
                    # check if loss or model weights are nan
                    if self.trainer.terminate_on_nan: