From 0203938af8f69a19b7e0264f18e03d543d86e0e9 Mon Sep 17 00:00:00 2001 From: Roshan Rao Date: Mon, 20 Apr 2020 05:03:52 -0700 Subject: [PATCH] Update learning rate on each backward pass instead of each forward pass. (#1477) * change lr scheduler step interval to update every backwards pass instead of every forwards pass * update CHANGELOG * fix spacing * Add TODO to lr schedule update * remove trailing whitespace Co-authored-by: William Falcon --- CHANGELOG.md | 10 +++++----- pytorch_lightning/trainer/training_loop.py | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0d8bdfc84..e3d6110c52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,20 +24,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475)) +- Changed lr schedule step interval behavior to update every backwards pass instead of every forwards pass ([#1476](https://github.com/PyTorchLightning/pytorch-lightning/issues/1476)) + - Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371)) -- ### Deprecated -- +- ### Removed -- +- -- +- ### Fixed @@ -52,7 +53,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)). - ## [0.7.3] - 2020-04-09 ### Added diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 538d717afc..ed0604c740 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -454,8 +454,11 @@ class TrainerTrainLoopMixin(ABC): # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 - # update lr - self.update_learning_rates(interval='step') + # TODO: consolidate all actions that need to take place only after + # self.accumulate_grad_batches steps (optimizer step, lr update, global step increment) + if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: + # update lr + self.update_learning_rates(interval='step') # --------------- # RUN VAL STEP