Move some tests to correct subfolder/file (#1312)
* move some tests to trainer file
* fix imports
This commit is contained in:
parent 6ddb03922a
commit d6646e151a
@@ -1,4 +1,3 @@
-import math
 import warnings
 
 import pytest
@@ -14,7 +13,6 @@ from tests.base import (
     LightTrainDataloader,
     LightningTestModel,
     LightTestMixin,
-    LightValidationMixin
 )
 
 
@@ -157,55 +155,6 @@ def test_running_test_without_val(tmpdir):
     tutils.assert_ok_model_acc(trainer)
 
 
-def test_disabled_validation():
-    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
-    tutils.reset_seed()
-
-    class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
-
-        validation_step_invoked = False
-        validation_end_invoked = False
-
-        def validation_step(self, *args, **kwargs):
-            self.validation_step_invoked = True
-            return super().validation_step(*args, **kwargs)
-
-        def validation_end(self, *args, **kwargs):
-            self.validation_end_invoked = True
-            return super().validation_end(*args, **kwargs)
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentModel(hparams)
-
-    trainer_options = dict(
-        show_progress_bar=False,
-        max_epochs=2,
-        train_percent_check=0.4,
-        val_percent_check=0.0,
-        fast_dev_run=False,
-    )
-
-    trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-
-    # check that val_percent_check=0 turns off validation
-    assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 1
-    assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
-    assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
-
-    # check that val_percent_check has no influence when fast_dev_run is turned on
-    model = CurrentModel(hparams)
-    trainer_options.update(fast_dev_run=True)
-    trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-
-    assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 0
-    assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
-    assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_single_gpu_batch_parse():
     tutils.reset_seed()
@@ -405,63 +354,5 @@ def test_single_gpu_model(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
-def test_nan_loss_detection(tmpdir):
-    test_step = 8
-
-    class InfLossModel(LightTrainDataloader, TestModelBase):
-
-        def training_step(self, batch, batch_idx):
-            output = super().training_step(batch, batch_idx)
-            if batch_idx == test_step:
-                if isinstance(output, dict):
-                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
-                else:
-                    output /= 0
-            return output
-
-    hparams = tutils.get_default_hparams()
-    model = InfLossModel(hparams)
-
-    # fit model
-    trainer = Trainer(
-        default_save_path=tmpdir,
-        max_steps=(test_step + 1),
-    )
-
-    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
-        trainer.fit(model)
-    assert trainer.global_step == test_step
-
-    for param in model.parameters():
-        assert torch.isfinite(param).all()
-
-
-def test_nan_params_detection(tmpdir):
-    test_step = 8
-
-    class NanParamModel(LightTrainDataloader, TestModelBase):
-
-        def on_after_backward(self):
-            if self.global_step == test_step:
-                # simulate parameter that became nan
-                torch.nn.init.constant_(self.c_d1.bias, math.nan)
-
-    hparams = tutils.get_default_hparams()
-
-    model = NanParamModel(hparams)
-    trainer = Trainer(
-        default_save_path=tmpdir,
-        max_steps=(test_step + 1),
-    )
-
-    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
-        trainer.fit(model)
-    assert trainer.global_step == test_step
-
-    # after aborting the training loop, model still has nan-valued params
-    params = torch.cat([param.view(-1) for param in model.parameters()])
-    assert not torch.isfinite(params).all()
-
-
 # if __name__ == '__main__':
 #     pytest.main([__file__])
@@ -24,6 +24,7 @@ from tests.base import (
     LightValidationMultipleDataloadersMixin,
     LightTrainDataloader,
     LightTestDataloader,
+    LightValidationMixin,
 )
 
 
@@ -518,3 +519,110 @@ def test_testpass_overrides(tmpdir):
     model = LightningTestModel(hparams)
     Trainer().test(model)
+
+
+def test_disabled_validation():
+    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
+    tutils.reset_seed()
+
+    class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
+
+        validation_step_invoked = False
+        validation_end_invoked = False
+
+        def validation_step(self, *args, **kwargs):
+            self.validation_step_invoked = True
+            return super().validation_step(*args, **kwargs)
+
+        def validation_end(self, *args, **kwargs):
+            self.validation_end_invoked = True
+            return super().validation_end(*args, **kwargs)
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentModel(hparams)
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.4,
+        val_percent_check=0.0,
+        fast_dev_run=False,
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # check that val_percent_check=0 turns off validation
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 1
+    assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
+    assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
+
+    # check that val_percent_check has no influence when fast_dev_run is turned on
+    model = CurrentModel(hparams)
+    trainer_options.update(fast_dev_run=True)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 0
+    assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
+    assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
+
+
+def test_nan_loss_detection(tmpdir):
+    test_step = 8
+
+    class InfLossModel(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            if batch_idx == test_step:
+                if isinstance(output, dict):
+                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
+                else:
+                    output /= 0
+            return output
+
+    hparams = tutils.get_default_hparams()
+    model = InfLossModel(hparams)
+
+    # fit model
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
+        trainer.fit(model)
+    assert trainer.global_step == test_step
+
+    for param in model.parameters():
+        assert torch.isfinite(param).all()
+
+
+def test_nan_params_detection(tmpdir):
+    test_step = 8
+
+    class NanParamModel(LightTrainDataloader, TestModelBase):
+
+        def on_after_backward(self):
+            if self.global_step == test_step:
+                # simulate parameter that became nan
+                torch.nn.init.constant_(self.c_d1.bias, math.nan)
+
+    hparams = tutils.get_default_hparams()
+
+    model = NanParamModel(hparams)
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
+        trainer.fit(model)
+    assert trainer.global_step == test_step
+
+    # after aborting the training loop, model still has nan-valued params
+    params = torch.cat([param.view(-1) for param in model.parameters()])
+    assert not torch.isfinite(params).all()
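
The two NaN tests moved above pin down the expected error messages but not the Trainer's internal implementation. As a rough mental model, the guard they exercise looks something like the following minimal sketch; the function names detect_nan_loss and detect_nan_parameters are illustrative assumptions, not Lightning's API:

    import torch

    def detect_nan_loss(loss):
        # raise as soon as the training loss stops being finite,
        # matching the message asserted in test_nan_loss_detection
        if not torch.isfinite(loss).all():
            raise ValueError('The loss returned in `training_step` is nan or inf.')

    def detect_nan_parameters(model):
        # scan every named parameter for nan/inf values,
        # matching the message asserted in test_nan_params_detection
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                raise ValueError(f'Detected nan and/or inf values in `{name}`.')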
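After a move like this, the relocated tests can be run by name to confirm pytest picks them up at their new location. This is a generic pytest invocation, not a command taken from the commit:

    pytest -v -k "test_disabled_validation or test_nan_loss_detection or test_nan_params_detection"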