# lightning/tests/trainer/flags/test_fast_dev_run.py
# (126 lines, 4.9 KiB, Python)

import os
from unittest.mock import Mock
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.base import DummyLogger
from tests.helpers import BoringModel
@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
    """Check that ``trainer.tune`` skips the selected tuner when ``fast_dev_run`` is on.

    ``tuner_alg`` decides which of the two tuner flags is enabled on the Trainer.
    The test expects a ``UserWarning`` naming that tuner.
    """
    use_batch_size_scaler = tuner_alg == "batch size scaler"
    use_lr_finder = tuner_alg == "learning rate finder"

    model = BoringModel()
    # Pre-set a learning rate so the LR-finder path does not hit a no-lr-found exception.
    model.lr = 0.1

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        auto_scale_batch_size=use_batch_size_scaler,
        auto_lr_find=use_lr_finder,
        fast_dev_run=True,
    )

    with pytest.warns(UserWarning, match=f"Skipping {tuner_alg} since fast_dev_run is enabled."):
        trainer.tune(model)
@pytest.mark.parametrize("fast_dev_run", [1, 4])
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
    """Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run.

    A model that counts its step/epoch hook invocations goes through ``fit`` and then
    ``test`` on a fresh Trainer. The hook counts, the trainer flags overridden by
    ``fast_dev_run``, the logger type and the (mocked) callback methods are then checked.
    This runs twice: once with a validation step and once with ``validation_step``
    set to ``None``.
    """

    class FastDevRunModel(BoringModel):
        """BoringModel that records how often each train/val/test hook is invoked."""

        def __init__(self):
            super().__init__()
            self.training_step_call_count = 0
            self.training_epoch_end_call_count = 0
            self.validation_step_call_count = 0
            self.validation_epoch_end_call_count = 0
            self.test_step_call_count = 0

        def training_step(self, batch, batch_idx):
            # Use both ``self.log`` and the raw experiment object. With fast_dev_run the
            # logger is expected to be a DummyLogger (asserted below), so these calls
            # presumably become no-ops.
            self.log("some_metric", torch.tensor(7.0))
            self.logger.experiment.dummy_log("some_distribution", torch.randn(7) + batch_idx)
            self.training_step_call_count += 1
            return super().training_step(batch, batch_idx)

        def training_epoch_end(self, outputs):
            self.training_epoch_end_call_count += 1
            super().training_epoch_end(outputs)

        def validation_step(self, batch, batch_idx):
            self.validation_step_call_count += 1
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.validation_epoch_end_call_count += 1
            super().validation_epoch_end(outputs)

        def test_step(self, batch, batch_idx):
            self.test_step_call_count += 1
            return super().test_step(batch, batch_idx)

    # Mock the methods that would do real work, so we can assert they were never invoked.
    # The same callback instances (and mocks) are shared by both runs below.
    checkpoint_callback = ModelCheckpoint()
    checkpoint_callback.save_checkpoint = Mock()
    early_stopping_callback = EarlyStopping()
    early_stopping_callback._evaluate_stopping_criteria = Mock()

    trainer_config = dict(
        default_root_dir=tmpdir,
        fast_dev_run=fast_dev_run,
        # Non-default value; fast_dev_run is expected to reset it to 1.0 (asserted below).
        val_check_interval=2,
        logger=True,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )

    def _make_fast_dev_run_assertions(trainer, model):
        # A model whose validation_step was removed is expected to run no validation hooks.
        has_validation = model.validation_step is not None

        # check the call count for train/val/test step/epoch
        assert model.training_step_call_count == fast_dev_run
        assert model.training_epoch_end_call_count == 1
        # The expected value must be parenthesized: ``assert a == b if c else d`` parses
        # as ``assert (a == b) if c else d``. That form asserts the truthy ``d`` whenever
        # ``c`` is false, so the count would never actually be checked.
        assert model.validation_step_call_count == (fast_dev_run if has_validation else 0)
        assert model.validation_epoch_end_call_count == (1 if has_validation else 0)
        assert model.test_step_call_count == fast_dev_run

        # check trainer arguments
        assert trainer.max_steps == fast_dev_run
        assert trainer.num_sanity_val_steps == 0
        assert trainer.max_epochs == 1
        assert trainer.val_check_interval == 1.0
        assert trainer.check_val_every_n_epoch == 1

        # there should be no logger with fast_dev_run
        assert isinstance(trainer.logger, DummyLogger)

        # checkpoint callback should not have been called with fast_dev_run
        assert trainer.checkpoint_callback == checkpoint_callback
        checkpoint_callback.save_checkpoint.assert_not_called()
        assert not os.path.exists(checkpoint_callback.dirpath)

        # early stopping should not have been called with fast_dev_run
        assert trainer.early_stopping_callback == early_stopping_callback
        early_stopping_callback._evaluate_stopping_criteria.assert_not_called()

    def _fit_test_and_check(model):
        # Build a fresh Trainer, run fit followed by test, then verify fast_dev_run behaviour.
        trainer = Trainer(**trainer_config)
        trainer.fit(model)
        trainer.test(model)
        assert trainer.state.finished, f"Training failed with {trainer.state}"
        _make_fast_dev_run_assertions(trainer, model)

    _fit_test_and_check(FastDevRunModel())

    # -----------------------
    # also called once with no val step
    # -----------------------
    train_step_only_model = FastDevRunModel()
    train_step_only_model.validation_step = None
    _fit_test_and_check(train_step_only_model)