diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 1c646a1221..0f76512041 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -899,6 +899,9 @@ We recommend you switch to ddp if you want to use amp
 
                 self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
+                # accumulate loss (if accumulate_grad_batches = 1 no effect)
+                loss = loss / self.accumulate_grad_batches
+
                 # backward pass
                 if self.use_amp:
                     # scale loss when using amp
@@ -918,12 +921,11 @@ We recommend you switch to ddp if you want to use amp
                     for param in model.parameters():
                         print(param.grad.float().sum())
 
-                # avoid memory leaks
+                # track total loss for logging (avoid mem leaks)
                 self.batch_loss_value += loss.item()
 
                 # gradient update with accumulated gradients
                 if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
-
                     # clip gradients
                     if self.gradient_clip > 0:
                         model = self.__get_model()
@@ -941,11 +943,7 @@ We recommend you switch to ddp if you want to use amp
                     # clear gradients
                     optimizer.zero_grad()
 
-                    # queuing loss across batches blows it up proportionally...
-                    # divide out the number accumulated
-                    self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches
-
-                    # track loss
+                    # calculate running loss for display
                     self.running_loss.append(self.batch_loss_value)
                     self.batch_loss_value = 0
                     self.avg_loss = np.mean(self.running_loss[-100:])
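
For reference, below is a minimal standalone sketch of the accumulation pattern this diff moves to, not the Lightning trainer itself (names such as `train_epoch`, `model`, `loader`, and `optimizer` are illustrative assumptions). Dividing each batch loss by `accumulate_grad_batches` before `backward()` makes the summed gradients over the window equal the gradient of the mean loss, and the summed `.item()` values are already window-averaged, so the old post-hoc division of `batch_loss_value` is no longer needed.

```python
import torch
import torch.nn.functional as F


def train_epoch(model, loader, optimizer, accumulate_grad_batches=4):
    """Sketch of gradient accumulation with pre-backward loss scaling."""
    running_loss = []
    batch_loss_value = 0.0

    for batch_nb, (x, y) in enumerate(loader):
        loss = F.cross_entropy(model(x), y)

        # accumulate loss (if accumulate_grad_batches = 1, no effect)
        loss = loss / accumulate_grad_batches
        loss.backward()

        # track total loss for logging; already scaled, so no post-hoc division
        batch_loss_value += loss.item()

        # gradient update with accumulated gradients
        if (batch_nb + 1) % accumulate_grad_batches == 0:
            optimizer.step()
            optimizer.zero_grad()

            running_loss.append(batch_loss_value)
            batch_loss_value = 0.0

    return running_loss
```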