2020-10-21 18:34:29 +00:00
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("num_steps", [1, 2, 3])
@patch("torch.Tensor.backward")
def test_backward_count_simple(torch_backward, tmpdir, num_steps):
    """ Test that backward is called exactly once per step.

    ``torch.Tensor.backward`` is replaced by a mock for the whole test, so its
    ``call_count`` records every backward call Lightning issues. Fitting for
    ``num_steps`` steps must produce exactly ``num_steps`` calls, and running
    ``trainer.test`` afterwards must not produce any.
    """
    # NOTE: ``torch_backward`` must remain the first parameter -- ``@patch``
    # passes the mock positionally; ``tmpdir``/``num_steps`` come from pytest.
    model = EvalModelTemplate()
    # Keep logs and checkpoints inside the per-test temp dir, not the CWD.
    trainer = Trainer(default_root_dir=tmpdir, max_steps=num_steps)
    trainer.fit(model)
    assert torch_backward.call_count == num_steps

    # Reset the counter so only calls made during testing are measured below.
    torch_backward.reset_mock()

    trainer.test(model)
    assert torch_backward.call_count == 0
|
|
|
|
|
|
|
|
|
|
|
|
@patch("torch.Tensor.backward")
def test_backward_count_with_grad_accumulation(torch_backward, tmpdir):
    """ Test that backward is called the correct number of times when accumulating gradients.

    The expected counts asserted below are:

    * one epoch limited to 6 training batches with ``accumulate_grad_batches=2``
      -> 6 backward calls (one per batch);
    * ``max_steps=6`` with ``accumulate_grad_batches=2`` -> 12 backward calls,
      i.e. each of the 6 steps consumes 2 batches.
    """
    # NOTE: ``torch_backward`` must remain the first parameter -- ``@patch``
    # passes the mock positionally; ``tmpdir`` is a pytest fixture.
    model = EvalModelTemplate()
    # Keep logs and checkpoints inside the per-test temp dir, not the CWD.
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=6,
        accumulate_grad_batches=2,
    )
    trainer.fit(model)
    assert torch_backward.call_count == 6

    # Reset the counter before the second, step-bounded run.
    torch_backward.reset_mock()

    trainer = Trainer(default_root_dir=tmpdir, max_steps=6, accumulate_grad_batches=2)
    trainer.fit(model)
    assert torch_backward.call_count == 12
|
2020-10-21 18:34:29 +00:00
|
|
|
|
|
|
|
|
|
|
|
@patch("torch.Tensor.backward")
def test_backward_count_with_closure(torch_backward, tmpdir):
    """ Using a closure (e.g. with LBFGS) should lead to no extra backward calls.

    The model's optimizer is swapped for its LBFGS variant. The expected
    counts asserted below are:

    * ``max_steps=5`` -> 5 backward calls;
    * ``max_steps=5`` with ``accumulate_grad_batches=2`` -> 10 backward calls.
    """
    # NOTE: ``torch_backward`` must remain the first parameter -- ``@patch``
    # passes the mock positionally; ``tmpdir`` is a pytest fixture.
    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__lbfgs
    # NOTE(review): presumably LBFGS evaluates the closure only once per step
    # here because the mocked backward never populates gradients -- confirm
    # if these counts change after an LBFGS / Lightning upgrade.
    # Keep logs and checkpoints inside the per-test temp dir, not the CWD.
    trainer = Trainer(default_root_dir=tmpdir, max_steps=5)
    trainer.fit(model)
    assert torch_backward.call_count == 5

    # Reset the counter before the run with gradient accumulation.
    torch_backward.reset_mock()

    trainer = Trainer(default_root_dir=tmpdir, max_steps=5, accumulate_grad_batches=2)
    trainer.fit(model)
    assert torch_backward.call_count == 10
|