parent
1eda7cfbda
commit
d3696052cf
|
@ -944,7 +944,9 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
|
|||
pytest.param(5),
|
||||
])
|
||||
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
||||
""" Test that the number of sanity check batches is clipped to limit_val_batches. """
|
||||
"""
|
||||
Test that the number of sanity check batches is clipped to `limit_val_batches`.
|
||||
"""
|
||||
model = EvalModelTemplate()
|
||||
model.validation_step = model.validation_step__multiple_dataloaders
|
||||
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
|
||||
|
@ -958,6 +960,16 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
|||
)
|
||||
assert trainer.num_sanity_val_steps == num_sanity_val_steps
|
||||
|
||||
with patch.object(
|
||||
trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
|
||||
) as mocked:
|
||||
val_dataloaders = model.val_dataloader__multiple_mixed_length()
|
||||
trainer.fit(model, val_dataloaders=val_dataloaders)
|
||||
|
||||
assert mocked.call_count == sum(
|
||||
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['limit_val_batches'], [
|
||||
pytest.param(0.0), # this should run no sanity checks
|
||||
|
@ -967,8 +979,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
|||
])
|
||||
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
||||
"""
|
||||
Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
|
||||
limited by "limit_val_batches" Trainer argument.
|
||||
Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as
|
||||
limited by `limit_val_batches` Trainer argument.
|
||||
"""
|
||||
model = EvalModelTemplate()
|
||||
model.validation_step = model.validation_step__multiple_dataloaders
|
||||
|
@ -981,6 +993,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
|||
)
|
||||
assert trainer.num_sanity_val_steps == float('inf')
|
||||
|
||||
with patch.object(
|
||||
trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
|
||||
) as mocked:
|
||||
val_dataloaders = model.val_dataloader__multiple()
|
||||
trainer.fit(model, val_dataloaders=val_dataloaders)
|
||||
|
||||
assert mocked.call_count == sum(trainer.num_val_batches)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainer_kwargs,expected", [
|
||||
pytest.param(
|
||||
|
|
Loading…
Reference in New Issue