Move some tests to correct subfolder/file (#1312)

* move some tests to trainer file

* fix imports
This commit is contained in:
Adrian Wälchli 2020-03-31 14:58:46 +02:00 committed by GitHub
parent 6ddb03922a
commit d6646e151a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 109 deletions

View File

@ -1,4 +1,3 @@
import math
import warnings import warnings
import pytest import pytest
@ -14,7 +13,6 @@ from tests.base import (
LightTrainDataloader, LightTrainDataloader,
LightningTestModel, LightningTestModel,
LightTestMixin, LightTestMixin,
LightValidationMixin
) )
@ -157,55 +155,6 @@ def test_running_test_without_val(tmpdir):
tutils.assert_ok_model_acc(trainer) tutils.assert_ok_model_acc(trainer)
def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
tutils.reset_seed()
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
validation_step_invoked = False
validation_end_invoked = False
def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)
def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)
hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)
trainer_options = dict(
show_progress_bar=False,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,
fast_dev_run=False,
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
trainer_options.update(fast_dev_run=True)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_single_gpu_batch_parse(): def test_single_gpu_batch_parse():
tutils.reset_seed() tutils.reset_seed()
@ -405,63 +354,5 @@ def test_single_gpu_model(tmpdir):
tutils.run_model_test(trainer_options, model) tutils.run_model_test(trainer_options, model)
def test_nan_loss_detection(tmpdir):
test_step = 8
class InfLossModel(LightTrainDataloader, TestModelBase):
def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
if batch_idx == test_step:
if isinstance(output, dict):
output['loss'] *= torch.tensor(math.inf) # make loss infinite
else:
output /= 0
return output
hparams = tutils.get_default_hparams()
model = InfLossModel(hparams)
# fit model
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)
with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
trainer.fit(model)
assert trainer.global_step == test_step
for param in model.parameters():
assert torch.isfinite(param).all()
def test_nan_params_detection(tmpdir):
test_step = 8
class NanParamModel(LightTrainDataloader, TestModelBase):
def on_after_backward(self):
if self.global_step == test_step:
# simulate parameter that became nan
torch.nn.init.constant_(self.c_d1.bias, math.nan)
hparams = tutils.get_default_hparams()
model = NanParamModel(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)
with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
trainer.fit(model)
assert trainer.global_step == test_step
# after aborting the training loop, model still has nan-valued params
params = torch.cat([param.view(-1) for param in model.parameters()])
assert not torch.isfinite(params).all()
# if __name__ == '__main__': # if __name__ == '__main__':
# pytest.main([__file__]) # pytest.main([__file__])

View File

@ -24,6 +24,7 @@ from tests.base import (
LightValidationMultipleDataloadersMixin, LightValidationMultipleDataloadersMixin,
LightTrainDataloader, LightTrainDataloader,
LightTestDataloader, LightTestDataloader,
LightValidationMixin,
) )
@ -518,3 +519,110 @@ def test_testpass_overrides(tmpdir):
model = LightningTestModel(hparams) model = LightningTestModel(hparams)
Trainer().test(model) Trainer().test(model)
def test_disabled_validation():
"""Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
tutils.reset_seed()
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
validation_step_invoked = False
validation_end_invoked = False
def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)
def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)
hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)
trainer_options = dict(
show_progress_bar=False,
max_epochs=2,
train_percent_check=0.4,
val_percent_check=0.0,
fast_dev_run=False,
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
trainer_options.update(fast_dev_run=True)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
def test_nan_loss_detection(tmpdir):
test_step = 8
class InfLossModel(LightTrainDataloader, TestModelBase):
def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
if batch_idx == test_step:
if isinstance(output, dict):
output['loss'] *= torch.tensor(math.inf) # make loss infinite
else:
output /= 0
return output
hparams = tutils.get_default_hparams()
model = InfLossModel(hparams)
# fit model
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)
with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
trainer.fit(model)
assert trainer.global_step == test_step
for param in model.parameters():
assert torch.isfinite(param).all()
def test_nan_params_detection(tmpdir):
test_step = 8
class NanParamModel(LightTrainDataloader, TestModelBase):
def on_after_backward(self):
if self.global_step == test_step:
# simulate parameter that became nan
torch.nn.init.constant_(self.c_d1.bias, math.nan)
hparams = tutils.get_default_hparams()
model = NanParamModel(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)
with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
trainer.fit(model)
assert trainer.global_step == test_step
# after aborting the training loop, model still has nan-valued params
params = torch.cat([param.view(-1) for param in model.parameters()])
assert not torch.isfinite(params).all()