import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(max_steps):
    """Verify the `on_before_zero_grad` hook fires once per optimizer step.

    The hook must not fire before `fit`, must fire exactly `max_steps`
    times during `fit`, and must not fire at all during `test` (which
    performs no optimization).
    """

    class HookCountingModel(EvalModelTemplate):
        # number of times `on_before_zero_grad` has fired on this model
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = HookCountingModel(tutils.get_default_hparams())

    trainer = Trainer(
        max_steps=max_steps,
        num_sanity_val_steps=5,
    )

    # nothing has run yet, so the hook must not have fired
    assert model.on_before_zero_grad_called == 0

    trainer.fit(model)
    # one call per optimizer step — sanity-val steps must not count
    assert model.on_before_zero_grad_called == max_steps

    # reset, then confirm evaluation never zeroes gradients
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0
|