diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index aae8aada28..60baec3f8b 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -1,4 +1,3 @@ -import math import warnings import pytest @@ -14,7 +13,6 @@ from tests.base import ( LightTrainDataloader, LightningTestModel, LightTestMixin, - LightValidationMixin ) @@ -157,55 +155,6 @@ def test_running_test_without_val(tmpdir): tutils.assert_ok_model_acc(trainer) -def test_disabled_validation(): - """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`.""" - tutils.reset_seed() - - class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase): - - validation_step_invoked = False - validation_end_invoked = False - - def validation_step(self, *args, **kwargs): - self.validation_step_invoked = True - return super().validation_step(*args, **kwargs) - - def validation_end(self, *args, **kwargs): - self.validation_end_invoked = True - return super().validation_end(*args, **kwargs) - - hparams = tutils.get_default_hparams() - model = CurrentModel(hparams) - - trainer_options = dict( - show_progress_bar=False, - max_epochs=2, - train_percent_check=0.4, - val_percent_check=0.0, - fast_dev_run=False, - ) - - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - # check that val_percent_check=0 turns off validation - assert result == 1, 'training failed to complete' - assert trainer.current_epoch == 1 - assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`' - assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`' - - # check that val_percent_check has no influence when fast_dev_run is turned on - model = CurrentModel(hparams) - trainer_options.update(fast_dev_run=True) - trainer = Trainer(**trainer_options) - result = trainer.fit(model) - - assert result == 1, 'training failed to complete' - assert trainer.current_epoch == 0 - assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`' - assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_single_gpu_batch_parse(): tutils.reset_seed() @@ -405,63 +354,5 @@ def test_single_gpu_model(tmpdir): tutils.run_model_test(trainer_options, model) -def test_nan_loss_detection(tmpdir): - test_step = 8 - - class InfLossModel(LightTrainDataloader, TestModelBase): - - def training_step(self, batch, batch_idx): - output = super().training_step(batch, batch_idx) - if batch_idx == test_step: - if isinstance(output, dict): - output['loss'] *= torch.tensor(math.inf) # make loss infinite - else: - output /= 0 - return output - - hparams = tutils.get_default_hparams() - model = InfLossModel(hparams) - - # fit model - trainer = Trainer( - default_save_path=tmpdir, - max_steps=(test_step + 1), - ) - - with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'): - trainer.fit(model) - assert trainer.global_step == test_step - - for param in model.parameters(): - assert torch.isfinite(param).all() - - -def test_nan_params_detection(tmpdir): - test_step = 8 - - class NanParamModel(LightTrainDataloader, TestModelBase): - - def on_after_backward(self): - if self.global_step == test_step: - # simulate parameter that became nan - torch.nn.init.constant_(self.c_d1.bias, math.nan) - - hparams = tutils.get_default_hparams() - - model = NanParamModel(hparams) - trainer = Trainer( - default_save_path=tmpdir, - max_steps=(test_step + 1), - ) - - with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'): - trainer.fit(model) - assert trainer.global_step == test_step - - # after aborting the training loop, model still has nan-valued params - params = torch.cat([param.view(-1) for param in model.parameters()]) - assert not torch.isfinite(params).all() - - # if __name__ == '__main__': # pytest.main([__file__]) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 307365223d..12ea59412f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -24,6 +24,7 @@ from tests.base import ( LightValidationMultipleDataloadersMixin, LightTrainDataloader, LightTestDataloader, + LightValidationMixin, ) @@ -518,3 +519,110 @@ def test_testpass_overrides(tmpdir): model = LightningTestModel(hparams) Trainer().test(model) + + +def test_disabled_validation(): + """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`.""" + tutils.reset_seed() + + class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase): + + validation_step_invoked = False + validation_end_invoked = False + + def validation_step(self, *args, **kwargs): + self.validation_step_invoked = True + return super().validation_step(*args, **kwargs) + + def validation_end(self, *args, **kwargs): + self.validation_end_invoked = True + return super().validation_end(*args, **kwargs) + + hparams = tutils.get_default_hparams() + model = CurrentModel(hparams) + + trainer_options = dict( + show_progress_bar=False, + max_epochs=2, + train_percent_check=0.4, + val_percent_check=0.0, + fast_dev_run=False, + ) + + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # check that val_percent_check=0 turns off validation + assert result == 1, 'training failed to complete' + assert trainer.current_epoch == 1 + assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`' + assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`' + + # check that val_percent_check has no influence when fast_dev_run is turned on + model = CurrentModel(hparams) + trainer_options.update(fast_dev_run=True) + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + assert result == 1, 'training failed to complete' + assert trainer.current_epoch == 0 + assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`' + assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`' + + +def test_nan_loss_detection(tmpdir): + test_step = 8 + + class InfLossModel(LightTrainDataloader, TestModelBase): + + def training_step(self, batch, batch_idx): + output = super().training_step(batch, batch_idx) + if batch_idx == test_step: + if isinstance(output, dict): + output['loss'] *= torch.tensor(math.inf) # make loss infinite + else: + output /= 0 + return output + + hparams = tutils.get_default_hparams() + model = InfLossModel(hparams) + + # fit model + trainer = Trainer( + default_save_path=tmpdir, + max_steps=(test_step + 1), + ) + + with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'): + trainer.fit(model) + assert trainer.global_step == test_step + + for param in model.parameters(): + assert torch.isfinite(param).all() + + +def test_nan_params_detection(tmpdir): + test_step = 8 + + class NanParamModel(LightTrainDataloader, TestModelBase): + + def on_after_backward(self): + if self.global_step == test_step: + # simulate parameter that became nan + torch.nn.init.constant_(self.c_d1.bias, math.nan) + + hparams = tutils.get_default_hparams() + + model = NanParamModel(hparams) + trainer = Trainer( + default_save_path=tmpdir, + max_steps=(test_step + 1), + ) + + with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'): + trainer.fit(model) + assert trainer.global_step == test_step + + # after aborting the training loop, model still has nan-valued params + params = torch.cat([param.view(-1) for param in model.parameters()]) + assert not torch.isfinite(params).all()