Move some tests to correct subfolder/file (#1312)
* move test_disabled_validation, test_nan_loss_detection and test_nan_params_detection to the trainer test file
* fix imports
parent 6ddb03922a
commit d6646e151a
@@ -1,4 +1,3 @@
-import math
 import warnings
 
 import pytest
@@ -14,7 +13,6 @@ from tests.base import (
     LightTrainDataloader,
     LightningTestModel,
     LightTestMixin,
-    LightValidationMixin
 )
 
 
@@ -157,55 +155,6 @@ def test_running_test_without_val(tmpdir):
     tutils.assert_ok_model_acc(trainer)
 
 
-def test_disabled_validation():
-    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
-    tutils.reset_seed()
-
-    class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
-
-        validation_step_invoked = False
-        validation_end_invoked = False
-
-        def validation_step(self, *args, **kwargs):
-            self.validation_step_invoked = True
-            return super().validation_step(*args, **kwargs)
-
-        def validation_end(self, *args, **kwargs):
-            self.validation_end_invoked = True
-            return super().validation_end(*args, **kwargs)
-
-    hparams = tutils.get_default_hparams()
-    model = CurrentModel(hparams)
-
-    trainer_options = dict(
-        show_progress_bar=False,
-        max_epochs=2,
-        train_percent_check=0.4,
-        val_percent_check=0.0,
-        fast_dev_run=False,
-    )
-
-    trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-
-    # check that val_percent_check=0 turns off validation
-    assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 1
-    assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
-    assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
-
-    # check that val_percent_check has no influence when fast_dev_run is turned on
-    model = CurrentModel(hparams)
-    trainer_options.update(fast_dev_run=True)
-    trainer = Trainer(**trainer_options)
-    result = trainer.fit(model)
-
-    assert result == 1, 'training failed to complete'
-    assert trainer.current_epoch == 0
-    assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
-    assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_single_gpu_batch_parse():
     tutils.reset_seed()
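For context, test_disabled_validation (re-added below in the trainer test file) exercises the rule that val_percent_check=0 turns validation off unless fast_dev_run=True overrides it. A minimal sketch of that rule, using a hypothetical helper should_run_validation rather than the Trainer's real internals:

def should_run_validation(val_percent_check, fast_dev_run):
    # fast_dev_run forces a single validation pass regardless of val_percent_check
    if fast_dev_run:
        return True
    # otherwise validation runs only when a non-zero fraction of batches is requested
    return val_percent_check > 0.0

assert not should_run_validation(0.0, fast_dev_run=False)  # first half of the test
assert should_run_validation(0.0, fast_dev_run=True)       # second half of the test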
@@ -405,63 +354,5 @@ def test_single_gpu_model(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
-def test_nan_loss_detection(tmpdir):
-    test_step = 8
-
-    class InfLossModel(LightTrainDataloader, TestModelBase):
-
-        def training_step(self, batch, batch_idx):
-            output = super().training_step(batch, batch_idx)
-            if batch_idx == test_step:
-                if isinstance(output, dict):
-                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
-                else:
-                    output /= 0
-            return output
-
-    hparams = tutils.get_default_hparams()
-    model = InfLossModel(hparams)
-
-    # fit model
-    trainer = Trainer(
-        default_save_path=tmpdir,
-        max_steps=(test_step + 1),
-    )
-
-    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
-        trainer.fit(model)
-        assert trainer.global_step == test_step
-
-    for param in model.parameters():
-        assert torch.isfinite(param).all()
-
-
-def test_nan_params_detection(tmpdir):
-    test_step = 8
-
-    class NanParamModel(LightTrainDataloader, TestModelBase):
-
-        def on_after_backward(self):
-            if self.global_step == test_step:
-                # simulate parameter that became nan
-                torch.nn.init.constant_(self.c_d1.bias, math.nan)
-
-    hparams = tutils.get_default_hparams()
-
-    model = NanParamModel(hparams)
-    trainer = Trainer(
-        default_save_path=tmpdir,
-        max_steps=(test_step + 1),
-    )
-
-    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
-        trainer.fit(model)
-        assert trainer.global_step == test_step
-
-    # after aborting the training loop, model still has nan-valued params
-    params = torch.cat([param.view(-1) for param in model.parameters()])
-    assert not torch.isfinite(params).all()
-
-
 # if __name__ == '__main__':
 #     pytest.main([__file__])
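Both nan tests removed here (and re-added below) rely on the Trainer raising a ValueError as soon as it sees non-finite values, either in the loss returned by training_step or in a parameter after backward. A minimal sketch of that style of check, using a hypothetical check_finite helper rather than the Trainer's actual implementation:

import torch

def check_finite(name, tensor):
    # mirror the error message the tests match with pytest.raises
    if not torch.isfinite(tensor).all():
        raise ValueError(f'Detected nan and/or inf values in `{name}`.')

check_finite('c_d1.bias', torch.zeros(3))  # finite values pass silently
# check_finite('c_d1.bias', torch.full((3,), float('nan')))  # would raise ValueError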
@@ -24,6 +24,7 @@ from tests.base import (
     LightValidationMultipleDataloadersMixin,
     LightTrainDataloader,
     LightTestDataloader,
+    LightValidationMixin,
 )
 
 
@@ -518,3 +519,110 @@ def test_testpass_overrides(tmpdir):
 
     model = LightningTestModel(hparams)
     Trainer().test(model)
+
+
+def test_disabled_validation():
+    """Verify that `val_percent_check=0` disables the validation loop unless `fast_dev_run=True`."""
+    tutils.reset_seed()
+
+    class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):
+
+        validation_step_invoked = False
+        validation_end_invoked = False
+
+        def validation_step(self, *args, **kwargs):
+            self.validation_step_invoked = True
+            return super().validation_step(*args, **kwargs)
+
+        def validation_end(self, *args, **kwargs):
+            self.validation_end_invoked = True
+            return super().validation_end(*args, **kwargs)
+
+    hparams = tutils.get_default_hparams()
+    model = CurrentModel(hparams)
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.4,
+        val_percent_check=0.0,
+        fast_dev_run=False,
+    )
+
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # check that val_percent_check=0 turns off validation
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 1
+    assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
+    assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
+
+    # check that val_percent_check has no influence when fast_dev_run is turned on
+    model = CurrentModel(hparams)
+    trainer_options.update(fast_dev_run=True)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    assert result == 1, 'training failed to complete'
+    assert trainer.current_epoch == 0
+    assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
+    assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
+
+
+def test_nan_loss_detection(tmpdir):
+    test_step = 8
+
+    class InfLossModel(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            if batch_idx == test_step:
+                if isinstance(output, dict):
+                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
+                else:
+                    output /= 0
+            return output
+
+    hparams = tutils.get_default_hparams()
+    model = InfLossModel(hparams)
+
+    # fit model
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
+        trainer.fit(model)
+        assert trainer.global_step == test_step
+
+    for param in model.parameters():
+        assert torch.isfinite(param).all()
+
+
+def test_nan_params_detection(tmpdir):
+    test_step = 8
+
+    class NanParamModel(LightTrainDataloader, TestModelBase):
+
+        def on_after_backward(self):
+            if self.global_step == test_step:
+                # simulate parameter that became nan
+                torch.nn.init.constant_(self.c_d1.bias, math.nan)
+
+    hparams = tutils.get_default_hparams()
+
+    model = NanParamModel(hparams)
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
+        trainer.fit(model)
+        assert trainer.global_step == test_step
+
+    # after aborting the training loop, model still has nan-valued params
+    params = torch.cat([param.view(-1) for param in model.parameters()])
+    assert not torch.isfinite(params).all()
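Once moved, the three tests can be run in isolation with a -k selection. The target path below is an assumption inferred from the hunk context, since this diff view does not show file names:

import pytest

# hypothetical path to the trainer test module the tests were moved into
pytest.main(['tests/trainer/test_trainer.py',
             '-k', 'disabled_validation or nan_loss or nan_params'])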