# lightning/tests/tests_pytorch/loops/test_training_loop.py
# (232 lines, 8.3 KiB, Python)

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock
import pytest
import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loops import FitLoop
def test_outputs_format(tmpdir):
    """Tests that outputs objects passed to model hooks and methods are consistent and in the correct format.

    ``training_step`` returns the default output plus an extra ``"foo"`` key. Both ``on_train_batch_end``
    (per batch) and ``training_epoch_end`` (per epoch) then assert that the key reaches them together
    with ``"loss"``.
    """

    class HookedModel(BoringModel):
        def training_step(self, batch, batch_idx):
            output = super().training_step(batch, batch_idx)
            self.log("foo", 123)
            # extra key that the hooks below check for
            output["foo"] = 123
            return output

        @staticmethod
        def _check_output(output):
            """Assert that a single step output carries ``"loss"`` and the injected ``"foo": 123``."""
            assert "loss" in output
            assert "foo" in output
            assert output["foo"] == 123

        def on_train_batch_end(self, outputs, batch, batch_idx):
            HookedModel._check_output(outputs)
            super().on_train_batch_end(outputs, batch, batch_idx)

        def training_epoch_end(self, outputs):
            # expected: one entry per training batch (limit_train_batches=2 below)
            assert len(outputs) == 2
            # plain loop: the checks run for their assertions, not to build a list
            for output in outputs:
                HookedModel._check_output(output)
            super().training_epoch_end(outputs)

    model = HookedModel()
    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=2,
        limit_test_batches=1,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model)
@pytest.mark.parametrize("seed_once", (True, False))
def test_training_starts_with_seed(tmpdir, seed_once):
    """Check how ``seed_everything`` governs the data seen by two consecutive ``Trainer.fit`` runs.

    The two runs differ only in ``num_sanity_val_steps``, which must leave the random state untouched.
    Re-seeding before each run therefore yields identical batches. Seeding only once before the first
    run makes the second run see different batches.
    """

    class SeededModel(BoringModel):
        def __init__(self):
            super().__init__()
            # flattened copy of every batch handed to training_step, in order
            self.seen_batches = []

        def training_step(self, batch, batch_idx):
            self.seen_batches.append(batch.view(-1))
            return super().training_step(batch, batch_idx)

    def run_training(**trainer_kwargs):
        """Fit a fresh SeededModel and return all of its training batches concatenated into one tensor."""
        seeded_model = SeededModel()
        Trainer(**trainer_kwargs).fit(seeded_model)
        return torch.cat(seeded_model.seen_batches)

    shared_kwargs = dict(default_root_dir=tmpdir, max_steps=2)

    seed_everything(123)
    sequence0 = run_training(num_sanity_val_steps=0, **shared_kwargs)
    if not seed_once:
        # reset the RNG so the second run starts from the same state as the first
        seed_everything(123)
    sequence1 = run_training(num_sanity_val_steps=2, **shared_kwargs)

    if seed_once:
        assert not torch.allclose(sequence0, sequence1)
    else:
        assert torch.allclose(sequence0, sequence1)
@pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
    """Returning ``-1`` from ``on_train_batch_start`` at ``batch_idx_`` must cut every epoch short at that batch.

    If ``batch_idx_`` lies beyond the last batch, the hook never returns ``-1`` and every epoch runs in full.
    """

    class SkipRestOfEpochModel(BoringModel):
        def on_train_batch_start(self, batch, batch_idx):
            # None (the implicit default) for every other batch
            return -1 if batch_idx == batch_idx_ else None

    model = SkipRestOfEpochModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=10)
    trainer.fit(model)

    num_batches = trainer.num_training_batches
    hook_never_fires = batch_idx_ > num_batches - 1
    expected_batch_idx = num_batches - 1 if hook_never_fires else batch_idx_
    expected_steps_per_epoch = num_batches if hook_never_fires else batch_idx_

    assert trainer.fit_loop.batch_idx == expected_batch_idx
    assert trainer.global_step == expected_steps_per_epoch * max_epochs
def test_should_stop_mid_epoch(tmpdir):
    """Setting ``trainer.should_stop`` mid-epoch must end training there while still running validation.

    Validation is expected to run in epoch 0 at the global step where training stopped.
    """

    class StopMidEpochModel(BoringModel):
        def __init__(self):
            super().__init__()
            # (current_epoch, global_step) captured by validation_step; stays None if validation never runs
            self.validation_called_at = None

        def training_step(self, batch, batch_idx):
            if batch_idx == 4:
                # request a stop while processing the fifth batch
                self.trainer.should_stop = True
            return super().training_step(batch, batch_idx)

        def validation_step(self, *args):
            owner = self.trainer
            self.validation_called_at = (owner.current_epoch, owner.global_step)
            return super().validation_step(*args)

    model = StopMidEpochModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=10, limit_val_batches=1)
    trainer.fit(model)

    # Although training stopped mid-epoch, the fit loop still finished normally, so the epoch counter advanced.
    assert (trainer.current_epoch, trainer.global_step) == (1, 5)
    assert model.validation_called_at == (0, 5)
def test_fit_loop_done_log_messages(caplog):
    """Check ``FitLoop.done`` against a mocked ``Trainer`` and the log message tied to each stopping reason.

    Each section mutates the trainer/loop state, asserts on ``fit_loop.done``, checks which message appears
    in the captured log text, and then clears the capture for the next section. The sections depend on
    the state left behind by the previous ones, so their order matters.
    """
    fit_loop = FitLoop(max_epochs=1)
    # a Mock restricted to the Trainer interface stands in for a real Trainer; the attributes the
    # checks rely on are assigned explicitly below
    trainer = Mock(spec=Trainer)
    fit_loop.trainer = trainer
    trainer.should_stop = False
    trainer.num_training_batches = 5
    # batches available and no stop requested: not done, and nothing logged yet
    assert not fit_loop.done
    assert not caplog.messages

    # no training batches: done, with an explanatory message
    trainer.num_training_batches = 0
    assert fit_loop.done
    assert "No training batches" in caplog.text
    caplog.clear()
    trainer.num_training_batches = 5

    # max_steps reached: a mocked epoch loop reports global_step == max_steps == 10
    epoch_loop = Mock()
    epoch_loop.global_step = 10
    fit_loop.connect(epoch_loop=epoch_loop)
    fit_loop.max_steps = 10
    assert fit_loop.done
    assert "max_steps=10` reached" in caplog.text
    caplog.clear()

    # max_epochs reached: max_steps is raised above global_step (10) so the step limit no longer applies.
    # should_stop is also set here, but the max_epochs message is the one asserted.
    fit_loop.max_steps = 20
    fit_loop.epoch_progress.current.processed = 3
    fit_loop.max_epochs = 3
    trainer.should_stop = True
    assert fit_loop.done
    assert "max_epochs=3` reached" in caplog.text
    caplog.clear()

    # should_stop: with max_epochs raised and min_steps=0, the stop request alone makes the loop done.
    # The capture is lowered to DEBUG on the rank_zero logger, presumably because this message is
    # logged at debug level.
    fit_loop.max_epochs = 5
    fit_loop.epoch_loop.min_steps = 0
    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
        assert fit_loop.done
    assert "should_stop` was set" in caplog.text

    # an unmet min_steps requirement (global_step is still 10) keeps the loop running despite should_stop
    fit_loop.epoch_loop.min_steps = 100
    assert not fit_loop.done
def test_warning_valid_train_step_end(tmpdir):
    """A ``training_step_end`` that computes the loss from the raw ``training_step`` output must fit without error."""

    class ValidTrainStepEndModel(BoringModel):
        def training_step(self, batch, batch_idx):
            # hand back only the forward pass; the loss is derived in training_step_end
            return {"output": self(batch)}

        def training_step_end(self, outputs):
            return self.loss(outputs["output"])

    # No error is raised
    model = ValidTrainStepEndModel()
    Trainer(default_root_dir=tmpdir, fast_dev_run=1).fit(model)
@pytest.mark.parametrize(
    "min_epochs, min_steps, current_epoch, early_stop, fit_loop_done, raise_debug_msg",
    [
        (4, None, 100, True, True, False),
        (4, None, 3, False, False, False),
        (4, 10, 3, False, False, False),
        (None, 10, 4, True, True, True),
        (4, None, 4, True, True, True),
        (4, 10, 4, True, True, True),
    ],
)
def test_should_stop_early_stopping_conditions_met(
    caplog, min_epochs, min_steps, current_epoch, early_stop, fit_loop_done, raise_debug_msg
):
    """Verify ``FitLoop.done``, ``_should_stop_early`` and the ``should_stop`` debug message for given min limits.

    ``trainer.should_stop`` is set up front. The loop's progress counters are then written directly to
    simulate ``current_epoch`` finished epochs of ten batches each, instead of actually running ``fit``.
    """
    trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0, max_epochs=100)
    trainer.num_training_batches = 10
    trainer.should_stop = True

    fit_loop = trainer.fit_loop
    epoch_loop = fit_loop.epoch_loop
    # completed optimizer steps: one per batch for every simulated epoch
    epoch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = (
        current_epoch * trainer.num_training_batches
    )
    epoch_loop.batch_progress.current.ready = 10
    fit_loop.epoch_progress.current.processed = current_epoch

    expected_debug_msg = "`Trainer.fit` stopped: `trainer.should_stop` was set."
    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
        assert fit_loop.done is fit_loop_done
    assert (expected_debug_msg in caplog.text) is raise_debug_msg
    assert fit_loop._should_stop_early is early_stop