22 lines
848 B
Python
22 lines
848 B
Python
import pytest
|
|
from pytorch_lightning import Trainer
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
    """Test that tuner algorithms are skipped if fast dev run is enabled.

    Each parametrized case switches on exactly one tuner option of the
    ``Trainer`` together with ``fast_dev_run=True`` and expects
    ``trainer.tune(model)`` to emit a ``UserWarning`` whose text names the
    selected algorithm.

    Args:
        tmpdir: Directory handed to the trainer as ``default_root_dir``
            (presumably pytest's built-in ``tmpdir`` fixture).
        tuner_alg: Human-readable name of the tuner algorithm under test;
            it is interpolated verbatim into the expected warning message.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        # Enable only the tuner option selected by the current parametrization;
        # the comparison already yields a bool, so no ternary is needed.
        auto_scale_batch_size=(tuner_alg == 'batch size scaler'),
        auto_lr_find=(tuner_alg == 'learning rate finder'),
        fast_dev_run=True,
    )
    # `pytest.warns(match=...)` treats this as a regular expression; the text
    # contains no regex metacharacters, so it matches literally.
    expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
    with pytest.warns(UserWarning, match=expected_message):
        trainer.tune(model)
|