Add optimizer hooks in callbacks (#4379)

* Add optimizer hooks in callbacks

* optimizer param

* update test

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Rohit Gupta 2020-10-28 17:45:22 +05:30 committed by GitHub
parent 00cc69aed7
commit b26c71eadf
4 changed files with 45 additions and 2 deletions


@@ -166,3 +166,15 @@ class Callback(abc.ABC):
     def on_load_checkpoint(self, checkpointed_state):
         """Called when loading a model checkpoint, use to reload state."""
         pass
+
+    def on_after_backward(self, trainer, pl_module):
+        """
+        Called after loss.backward() and before optimizers do anything.
+        """
+        pass
+
+    def on_before_zero_grad(self, trainer, pl_module, optimizer):
+        """
+        Called after optimizer.step() and before optimizer.zero_grad().
+        """
+        pass
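
For orientation, a minimal sketch of a user-defined callback that exercises the two new hooks; the GradientMonitor name and the printed diagnostics are illustrative, not part of this commit:

import torch
from pytorch_lightning.callbacks import Callback

class GradientMonitor(Callback):
    def on_after_backward(self, trainer, pl_module):
        # Fires after loss.backward(): .grad fields are populated,
        # but optimizer.step() has not run yet.
        norms = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
        if norms:
            print(f'total grad norm: {torch.stack(norms).norm():.4f}')

    def on_before_zero_grad(self, trainer, pl_module, optimizer):
        # Fires after optimizer.step(), just before gradients are cleared.
        print(f"current lr: {optimizer.param_groups[0]['lr']}")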


@@ -209,3 +209,17 @@ class TrainerCallbackHookMixin(ABC):
             if state:
                 state = deepcopy(state)
                 callback.on_load_checkpoint(state)
+
+    def on_after_backward(self):
+        """
+        Called after loss.backward() and before optimizers do anything.
+        """
+        for callback in self.callbacks:
+            callback.on_after_backward(self, self.get_model())
+
+    def on_before_zero_grad(self, optimizer):
+        """
+        Called after optimizer.step() and before optimizer.zero_grad().
+        """
+        for callback in self.callbacks:
+            callback.on_before_zero_grad(self, self.get_model(), optimizer)
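
The dispatch above follows the existing mixin pattern: each hook fans out to every registered callback, passing the trainer and the current LightningModule along. A simplified sketch of where the two hooks land in one optimization step; train_step and its arguments are illustrative stand-ins, not the actual TrainLoop code:

def train_step(trainer, model, optimizer, batch, batch_idx):
    loss = model.training_step(batch, batch_idx)
    loss.backward()
    trainer.on_after_backward()             # callbacks observe populated .grad
    optimizer.step()
    trainer.on_before_zero_grad(optimizer)  # callbacks observe the stepped optimizer
    optimizer.zero_grad()
    return loss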


@@ -454,8 +454,7 @@ class TrainLoop:
         )

     def on_before_zero_grad(self, optimizer):
-        model = self.trainer.get_model()
-        model.on_before_zero_grad(optimizer)
+        self.trainer.call_hook('on_before_zero_grad', optimizer)

     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
         self.trainer.accelerator_backend.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
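
Routing through call_hook is what makes the callback version fire: the hook now reaches both the registered callbacks and the LightningModule's own on_before_zero_grad, where the old code called only the module. A much-simplified sketch of that dispatch, assuming the real call_hook also deals with accelerator hooks and return values:

def call_hook(self, hook_name, *args):
    # Run the trainer-level hook first; the mixin method above fans it
    # out to every registered callback.
    if hasattr(self, hook_name):
        getattr(self, hook_name)(*args)
    # Then run the LightningModule's implementation of the same hook.
    model = self.get_model()
    if model is not None and hasattr(model, hook_name):
        getattr(model, hook_name)(*args)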


@@ -55,6 +55,8 @@ def test_trainer_callback_system(tmpdir):
             self.on_validation_end_called = False
             self.on_test_start_called = False
             self.on_test_end_called = False
+            self.on_after_backward_called = False
+            self.on_before_zero_grad_called = False

         def setup(self, trainer, pl_module, stage: str):
             assert isinstance(trainer, Trainer)
@@ -160,6 +162,14 @@
             _check_args(trainer, pl_module)
             self.on_test_end_called = True
+
+        def on_after_backward(self, trainer, pl_module):
+            _check_args(trainer, pl_module)
+            self.on_after_backward_called = True
+
+        def on_before_zero_grad(self, trainer, pl_module, optimizer):
+            _check_args(trainer, pl_module)
+            self.on_before_zero_grad_called = True

     test_callback = TestCallback()

     trainer_options = dict(
@@ -197,6 +207,8 @@
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called

     # fit model
     trainer = Trainer(**trainer_options)
@@ -228,6 +240,8 @@
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called

     trainer.fit(model)
@@ -257,6 +271,8 @@
     assert not test_callback.on_test_batch_end_called
     assert not test_callback.on_test_start_called
     assert not test_callback.on_test_end_called
+    assert test_callback.on_after_backward_called
+    assert test_callback.on_before_zero_grad_called

     # reset setup teardown callback
     test_callback.teardown_called = False
@@ -277,3 +293,5 @@
     assert not test_callback.on_validation_end_called
     assert not test_callback.on_validation_batch_end_called
     assert not test_callback.on_validation_batch_start_called
+    assert not test_callback.on_after_backward_called
+    assert not test_callback.on_before_zero_grad_called
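
Putting it together, wiring such a callback into a run works like any other callback; GradientMonitor refers to the illustrative sketch above, and MyModel stands in for a real LightningModule:

from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[GradientMonitor()], max_epochs=1)
trainer.fit(MyModel())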