lightning/tests/trainer/optimization/test_backward_calls.py

53 lines
1.6 KiB
Python

from unittest.mock import patch
import pytest
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate
@pytest.mark.parametrize("num_steps", [1, 2, 3])
@patch("torch.Tensor.backward")
def test_backward_count_simple(torch_backward, num_steps):
    """Fitting should call backward once per step; testing should never call it.

    Args:
        torch_backward: Mock standing in for ``torch.Tensor.backward``,
            injected by the ``patch`` decorator.
        num_steps: Number of training steps to run, supplied by ``parametrize``.
    """
    template_model = EvalModelTemplate()
    runner = Trainer(max_steps=num_steps)

    # Training phase: the number of recorded backward calls must equal the step budget.
    runner.fit(template_model)
    calls_during_fit = torch_backward.call_count
    assert calls_during_fit == num_steps

    # Forget the training calls so the evaluation phase is counted from zero.
    torch_backward.reset_mock()

    # Evaluation phase: no backward call is expected at all.
    runner.test(template_model)
    torch_backward.assert_not_called()
@patch("torch.Tensor.backward")
def test_backward_count_with_grad_accumulation(torch_backward):
    """Verify the backward call count when gradients are accumulated over batches.

    Two trainers are fitted on the same model: the first is limited by the
    number of batches in a single epoch, the second by ``max_steps``.

    Args:
        torch_backward: Mock standing in for ``torch.Tensor.backward``,
            injected by the ``patch`` decorator.
    """
    accumulation = 2
    batch_limit = 6
    step_limit = 6
    template_model = EvalModelTemplate()

    # Run 1: one epoch capped at `batch_limit` batches; one backward call is
    # expected per batch.
    epoch_bound_trainer = Trainer(
        max_epochs=1,
        limit_train_batches=batch_limit,
        accumulate_grad_batches=accumulation,
    )
    epoch_bound_trainer.fit(template_model)
    assert torch_backward.call_count == batch_limit

    # Start the second run with a clean call record.
    torch_backward.reset_mock()

    # Run 2: capped by `max_steps`; the expected count is the step limit
    # multiplied by the accumulation factor.
    step_bound_trainer = Trainer(
        max_steps=step_limit,
        accumulate_grad_batches=accumulation,
    )
    step_bound_trainer.fit(template_model)
    assert torch_backward.call_count == step_limit * accumulation
@patch("torch.Tensor.backward")
def test_backward_count_with_closure(torch_backward):
    """An optimizer that uses a closure (e.g. LBFGS) must not add backward calls.

    Args:
        torch_backward: Mock standing in for ``torch.Tensor.backward``,
            injected by the ``patch`` decorator.
    """
    step_limit = 5
    accumulation = 2

    template_model = EvalModelTemplate()
    # Swap in the template's LBFGS optimizer configuration for this test.
    template_model.configure_optimizers = template_model.configure_optimizers__lbfgs

    # Run 1: no gradient accumulation; one backward call is expected per step.
    plain_trainer = Trainer(max_steps=step_limit)
    plain_trainer.fit(template_model)
    assert torch_backward.call_count == step_limit

    # Start the second run with a clean call record.
    torch_backward.reset_mock()

    # Run 2: with accumulation, the expected count is the step limit
    # multiplied by the accumulation factor.
    accumulating_trainer = Trainer(
        max_steps=step_limit,
        accumulate_grad_batches=accumulation,
    )
    accumulating_trainer.fit(template_model)
    assert torch_backward.call_count == step_limit * accumulation