lightning/tests/models/test_hooks.py

30 lines
781 B
Python
Raw Normal View History

import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
@pytest.mark.parametrize('max_steps', [1, 2, 3])
def test_on_before_zero_grad_called(tmpdir, max_steps):
    """Check that ``on_before_zero_grad`` runs once per optimizer step during fit only.

    Fitting for ``max_steps`` steps must call the hook exactly ``max_steps``
    times, even though 5 sanity-check validation steps also run during fit.
    Running ``trainer.test`` afterwards must not call it at all.
    """

    class CurrentTestModel(EvalModelTemplate):
        # Counts hook calls. The class-level 0 is only the starting value:
        # the first ``+=`` below writes the incremented count onto the
        # instance, so each model keeps its own tally.
        on_before_zero_grad_called = 0

        def on_before_zero_grad(self, optimizer):
            self.on_before_zero_grad_called += 1

    model = CurrentTestModel()

    trainer = Trainer(
        # Keep logs and checkpoints inside pytest's per-test temp dir.
        # Per the Lightning Trainer docs, the default root is the current
        # working directory, which would leak artifacts across runs.
        default_root_dir=tmpdir,
        max_steps=max_steps,
        num_sanity_val_steps=5,
    )
    assert model.on_before_zero_grad_called == 0
    trainer.fit(model)
    assert model.on_before_zero_grad_called == max_steps

    # Reset the counter, then verify that the test loop never zeroes gradients.
    model.on_before_zero_grad_called = 0
    trainer.test(model)
    assert model.on_before_zero_grad_called == 0