faster tests (#3604)

This commit is contained in:
Nicki Skafte 2020-09-22 13:37:34 +02:00 committed by GitHub
parent 02ce7e86a6
commit 88e6b29bba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 2 deletions

View File

@ -199,6 +199,7 @@ def test_trainer_reset_correctly(tmpdir):
f'Attribute {key} was not reset correctly after learning rate finder'
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True])
def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
""" Test possible values for 'batch size auto scaling' Trainer argument. """
@ -206,13 +207,17 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
before_batch_size = hparams.get('batch_size')
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg)
trainer = Trainer(default_root_dir=tmpdir,
max_epochs=1,
auto_scale_batch_size=scale_arg,
gpus=1)
trainer.tune(model)
after_batch_size = model.batch_size
assert before_batch_size != after_batch_size, \
'Batch size was not altered after running auto scaling of batch size'
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.parametrize('use_hparams', [True, False])
def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
""" Test that new batch size gets written to the correct hyperparameter attribute. """
@ -238,7 +243,10 @@ def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
model = model_class(**hparams)
model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer = Trainer(default_root_dir=tmpdir,
max_epochs=1,
auto_scale_batch_size=True,
gpus=1)
trainer.tune(model, datamodule_fit)
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert trainer.datamodule == datamodule_fit