From b26c71eadf11d3a6aa9428504b8e646e0a8a03f8 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Wed, 28 Oct 2020 17:45:22 +0530
Subject: [PATCH] Add optimizer hooks in callbacks (#4379)

* Add optimizer hooks in callbacks

* optimizer param

* update test

Co-authored-by: Nicki Skafte
---
 pytorch_lightning/callbacks/base.py        | 12 ++++++++++++
 pytorch_lightning/trainer/callback_hook.py | 14 ++++++++++++++
 pytorch_lightning/trainer/training_loop.py |  3 +--
 tests/callbacks/test_callbacks.py          | 18 ++++++++++++++++++
 4 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 591703c245..004aa6d737 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -166,3 +166,15 @@ class Callback(abc.ABC):
     def on_load_checkpoint(self, checkpointed_state):
         """Called when loading a model checkpoint, use to reload state."""
         pass
+
+    def on_after_backward(self, trainer, pl_module):
+        """
+        Called after loss.backward() and before optimizers do anything.
+        """
+        pass
+
+    def on_before_zero_grad(self, trainer, pl_module, optimizer):
+        """
+        Called after optimizer.step() and before optimizer.zero_grad().
+        """
+        pass
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 46f2a32c0a..8f3885c20f 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -209,3 +209,17 @@ class TrainerCallbackHookMixin(ABC):
             if state:
                 state = deepcopy(state)
                 callback.on_load_checkpoint(state)
+
+    def on_after_backward(self):
+        """
+        Called after loss.backward() and before optimizers do anything.
+        """
+        for callback in self.callbacks:
+            callback.on_after_backward(self, self.get_model())
+
+    def on_before_zero_grad(self, optimizer):
+        """
+        Called after optimizer.step() and before optimizer.zero_grad().
+        """
+        for callback in self.callbacks:
+            callback.on_before_zero_grad(self, self.get_model(), optimizer)
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index d32f47dbbd..934938c63f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -454,8 +454,7 @@ class TrainLoop:
         )
 
     def on_before_zero_grad(self, optimizer):
-        model = self.trainer.get_model()
-        model.on_before_zero_grad(optimizer)
+        self.trainer.call_hook('on_before_zero_grad', optimizer)
 
     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
         self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index bb7ec8430a..cf88f52436 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -55,6 +55,8 @@ def test_trainer_callback_system(tmpdir):
             self.on_validation_end_called = False
             self.on_test_start_called = False
             self.on_test_end_called = False
+            self.on_after_backward_called = False
+            self.on_before_zero_grad_called = False
 
         def setup(self, trainer, pl_module, stage: str):
             assert isinstance(trainer, Trainer)
@@ -160,6 +162,14 @@ def test_trainer_callback_system(tmpdir):
             _check_args(trainer, pl_module)
             self.on_test_end_called = True
 
+        def on_after_backward(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_after_backward_called = True
+
+        def on_before_zero_grad(self, trainer, pl_module, optimizer):
+            _check_args(trainer, pl_module)
+            self.on_before_zero_grad_called = True
+
     test_callback = TestCallback()
 
     trainer_options = dict(
@@ -197,6 +207,8 @@ def test_trainer_callback_system(tmpdir):
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called
 
     # fit model
     trainer = Trainer(**trainer_options)
@@ -228,6 +240,8 @@ def test_trainer_callback_system(tmpdir):
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called
 
     trainer.fit(model)
 
@@ -257,6 +271,8 @@ def test_trainer_callback_system(tmpdir):
     assert not test_callback.on_test_batch_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert test_callback.on_after_backward_called
+    assert test_callback.on_before_zero_grad_called
 
     # reset setup teardown callback
     test_callback.teardown_called = False
@@ -277,3 +293,5 @@ def test_trainer_callback_system(tmpdir):
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_validation_batch_end_called
     assert not test_callback.on_validation_batch_start_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called