126 lines
4.9 KiB
Python
126 lines
4.9 KiB
Python
import os
|
|
from unittest.mock import Mock
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
|
from pytorch_lightning.loggers.base import DummyLogger
|
|
from tests.helpers import BoringModel
|
|
|
|
|
|
@pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"])
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
    """Check that ``trainer.tune`` skips the selected tuner when fast_dev_run is on.

    Each parametrized case enables exactly one tuner (batch-size scaling or
    learning-rate finding). Calling ``trainer.tune`` must emit a ``UserWarning``
    that names the skipped algorithm.
    """
    scale_batch_size = tuner_alg == "batch size scaler"
    find_lr = tuner_alg == "learning rate finder"

    boring = BoringModel()
    # Pre-set a learning rate so the tuner does not raise a no-lr-found exception.
    boring.lr = 0.1

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        fast_dev_run=True,
        auto_scale_batch_size=scale_batch_size,
        auto_lr_find=find_lr,
    )

    with pytest.warns(UserWarning, match=f"Skipping {tuner_alg} since fast_dev_run is enabled."):
        trainer.tune(boring)
|
|
|
|
|
|
@pytest.mark.parametrize("fast_dev_run", [1, 4])
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
    """Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run.

    ``fit`` followed by ``test`` is run twice with the same callbacks:
    first with a model that defines a validation step, then with
    ``validation_step`` set to ``None``.

    After each run the test checks:
    - the per-hook call counts;
    - the trainer arguments that ``fast_dev_run`` overrides;
    - that the logger is a ``DummyLogger``;
    - that neither callback did any work.
    """

    class FastDevRunModel(BoringModel):
        """``BoringModel`` that counts how many times each train/val/test hook runs."""

        def __init__(self):
            super().__init__()
            self.training_step_call_count = 0
            self.training_epoch_end_call_count = 0
            self.validation_step_call_count = 0
            self.validation_epoch_end_call_count = 0
            self.test_step_call_count = 0

        def training_step(self, batch, batch_idx):
            self.log("some_metric", torch.tensor(7.0))
            # NOTE(review): ``dummy_log`` is not a standard experiment method; this
            # presumably relies on the DummyLogger's experiment accepting arbitrary
            # calls — confirm if the logger under fast_dev_run ever changes.
            self.logger.experiment.dummy_log("some_distribution", torch.randn(7) + batch_idx)
            self.training_step_call_count += 1
            return super().training_step(batch, batch_idx)

        def training_epoch_end(self, outputs):
            self.training_epoch_end_call_count += 1
            super().training_epoch_end(outputs)

        def validation_step(self, batch, batch_idx):
            self.validation_step_call_count += 1
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.validation_epoch_end_call_count += 1
            super().validation_epoch_end(outputs)

        def test_step(self, batch, batch_idx):
            self.test_step_call_count += 1
            return super().test_step(batch, batch_idx)

    # Swap the methods that would do real work for mocks, so the assertions can
    # verify they were never invoked. Both runs below share these callbacks, so
    # ``assert_not_called`` checks the cumulative calls across both runs.
    checkpoint_callback = ModelCheckpoint()
    checkpoint_callback.save_checkpoint = Mock()
    early_stopping_callback = EarlyStopping()
    early_stopping_callback._evaluate_stopping_criteria = Mock()
    trainer_config = dict(
        default_root_dir=tmpdir,
        fast_dev_run=fast_dev_run,
        # Deliberately non-default: the assertions expect fast_dev_run to force 1.0.
        val_check_interval=2,
        # Deliberately enabled: the assertions expect fast_dev_run to swap in a DummyLogger.
        logger=True,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )

    def _make_fast_dev_run_assertions(trainer, model):
        """Assert the effects of fast_dev_run on ``trainer`` after fitting/testing ``model``."""
        has_val_step = model.validation_step is not None

        # check the call count for train/val/test step/epoch
        assert model.training_step_call_count == fast_dev_run
        assert model.training_epoch_end_call_count == 1
        # The expected value must be parenthesized. Without parentheses,
        # ``assert a == 0 if cond else n`` parses as
        # ``assert (a == 0) if cond else n``, whose else-branch only checks
        # that ``n`` is truthy.
        assert model.validation_step_call_count == (fast_dev_run if has_val_step else 0)
        assert model.validation_epoch_end_call_count == (1 if has_val_step else 0)
        assert model.test_step_call_count == fast_dev_run

        # check trainer arguments
        assert trainer.max_steps == fast_dev_run
        assert trainer.num_sanity_val_steps == 0
        assert trainer.max_epochs == 1
        assert trainer.val_check_interval == 1.0
        assert trainer.check_val_every_n_epoch == 1

        # there should be no logger with fast_dev_run
        assert isinstance(trainer.logger, DummyLogger)

        # checkpoint callback should not have been called with fast_dev_run
        assert trainer.checkpoint_callback == checkpoint_callback
        checkpoint_callback.save_checkpoint.assert_not_called()
        assert not os.path.exists(checkpoint_callback.dirpath)

        # early stopping should not have been called with fast_dev_run
        assert trainer.early_stopping_callback == early_stopping_callback
        early_stopping_callback._evaluate_stopping_criteria.assert_not_called()

    def _fit_test_and_check(model):
        """Run fit + test on a fresh Trainer, then apply the fast_dev_run assertions."""
        trainer = Trainer(**trainer_config)
        trainer.fit(model)
        trainer.test(model)

        assert trainer.state.finished, f"Training failed with {trainer.state}"
        _make_fast_dev_run_assertions(trainer, model)

    _fit_test_and_check(FastDevRunModel())

    # -----------------------
    # also called once with no val step
    # -----------------------
    train_step_only_model = FastDevRunModel()
    train_step_only_model.validation_step = None
    _fit_test_and_check(train_step_only_model)
|