Add back sanity checks (#3846)

* Add back sanity checks

* pep
This commit is contained in:
Rohit Gupta 2020-10-05 02:35:26 +05:30 committed by GitHub
parent 1eda7cfbda
commit d3696052cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 23 additions and 3 deletions

View File

@@ -944,7 +944,9 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
    pytest.param(5),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
    """
    Test that the number of sanity check batches is clipped to `limit_val_batches`.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
@@ -958,6 +960,16 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
    )
    assert trainer.num_sanity_val_steps == num_sanity_val_steps
    with patch.object(
        trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
    ) as mocked:
        val_dataloaders = model.val_dataloader__multiple_mixed_length()
        trainer.fit(model, val_dataloaders=val_dataloaders)

        assert mocked.call_count == sum(
            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
        )
@pytest.mark.parametrize(['limit_val_batches'], [
    pytest.param(0.0),  # this should run no sanity checks
@@ -967,8 +979,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
])
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    """
    Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as
    limited by `limit_val_batches` Trainer argument.
    """
    model = EvalModelTemplate()
    model.validation_step = model.validation_step__multiple_dataloaders
@@ -981,6 +993,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    )
    assert trainer.num_sanity_val_steps == float('inf')
    with patch.object(
        trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
    ) as mocked:
        val_dataloaders = model.val_dataloader__multiple()
        trainer.fit(model, val_dataloaders=val_dataloaders)

        assert mocked.call_count == sum(trainer.num_val_batches)
@pytest.mark.parametrize("trainer_kwargs,expected", [
    pytest.param(