faster tests (#3604)
parent 02ce7e86a6
commit 88e6b29bba
@@ -199,6 +199,7 @@ def test_trainer_reset_correctly(tmpdir):
         f'Attribute {key} was not reset correctly after learning rate finder'


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 @pytest.mark.parametrize('scale_arg', ['power', 'binsearch', True])
 def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
     """ Test possible values for 'batch size auto scaling' Trainer argument. """
@@ -206,13 +207,17 @@ def test_auto_scale_batch_size_trainer_arg(tmpdir, scale_arg):
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(**hparams)
     before_batch_size = hparams.get('batch_size')
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=scale_arg)
+    trainer = Trainer(default_root_dir=tmpdir,
+                      max_epochs=1,
+                      auto_scale_batch_size=scale_arg,
+                      gpus=1)
     trainer.tune(model)
     after_batch_size = model.batch_size
     assert before_batch_size != after_batch_size, \
         'Batch size was not altered after running auto scaling of batch size'


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 @pytest.mark.parametrize('use_hparams', [True, False])
 def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
     """ Test that new batch size gets written to the correct hyperparameter attribute. """
@@ -238,7 +243,10 @@ def test_auto_scale_batch_size_set_model_attribute(tmpdir, use_hparams):
     model = model_class(**hparams)
     model.datamodule = datamodule_model  # unused when another module gets passed to .tune() / .fit()

-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    trainer = Trainer(default_root_dir=tmpdir,
+                      max_epochs=1,
+                      auto_scale_batch_size=True,
+                      gpus=1)
     trainer.tune(model, datamodule_fit)
     after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
     assert trainer.datamodule == datamodule_fit
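In short, the commit speeds these tests up by running the batch-size auto-scaler on a single GPU (`gpus=1`) and skipping the tests entirely on CPU-only machines. A minimal sketch of that gating pattern, not part of the commit itself (the test name and body here are hypothetical):

import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_gpu_only_example():
    # Hypothetical test: collected everywhere, but reported as 'skipped'
    # rather than 'failed' on CPU-only CI workers, so moving the real
    # tuner tests onto the GPU is safe to land.
    assert torch.cuda.device_count() >= 1

With `gpus=1`, the repeated forward passes that `trainer.tune(...)` runs while probing batch sizes execute on the GPU, which is presumably where the saved wall-clock time comes from.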