diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 78711b637e..d27be6513b 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -944,7 +944,9 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     pytest.param(5),
 ])
 def test_num_sanity_val_steps(tmpdir, limit_val_batches):
-    """ Test that the number of sanity check batches is clipped to limit_val_batches. """
+    """
+    Test that the number of sanity check batches is clipped to `limit_val_batches`.
+    """
     model = EvalModelTemplate()
     model.validation_step = model.validation_step__multiple_dataloaders
     model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
@@ -958,6 +960,16 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
     )
     assert trainer.num_sanity_val_steps == num_sanity_val_steps
 
+    with patch.object(
+        trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
+    ) as mocked:
+        val_dataloaders = model.val_dataloader__multiple_mixed_length()
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+
+        assert mocked.call_count == sum(
+            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
+        )
+
 
 @pytest.mark.parametrize(['limit_val_batches'], [
     pytest.param(0.0),  # this should run no sanity checks
@@ -967,8 +979,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
 ])
 def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
     """
-    Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
-    limited by "limit_val_batches" Trainer argument.
+    Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as
+    limited by the `limit_val_batches` Trainer argument.
     """
     model = EvalModelTemplate()
     model.validation_step = model.validation_step__multiple_dataloaders
@@ -981,6 +993,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
    )
     assert trainer.num_sanity_val_steps == float('inf')
 
+    with patch.object(
+        trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
+    ) as mocked:
+        val_dataloaders = model.val_dataloader__multiple()
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+
+        assert mocked.call_count == sum(trainer.num_val_batches)
+
 
 @pytest.mark.parametrize("trainer_kwargs,expected", [
     pytest.param(
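
Note on the mocking pattern: both new assertions rely on `unittest.mock.patch.object` with `wraps=`, which replaces a method with a `MagicMock` that still delegates to the original, so the run behaves normally while `call_count` records how many times the method fired. Below is a minimal, self-contained sketch of that pattern; the `EvaluationLoop` class and the loop driving it are illustrative stand-ins, not Lightning internals.

# Sketch only, not part of the patch. Demonstrates counting calls with
# patch.object(obj, name, wraps=original) while preserving behaviour.
from unittest.mock import patch


class EvaluationLoop:  # hypothetical stand-in, not the Lightning class
    def evaluation_step(self, batch):
        return batch * 2  # placeholder for the real per-batch validation work


loop = EvaluationLoop()
with patch.object(loop, 'evaluation_step', wraps=loop.evaluation_step) as mocked:
    results = [loop.evaluation_step(batch) for batch in range(3)]

    # `wraps` preserved the original return values...
    assert results == [0, 2, 4]
    # ...while the mock counted every call, which is what the
    # `mocked.call_count == sum(...)` assertions in the tests check.
    assert mocked.call_count == 3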