parent
1eda7cfbda
commit
d3696052cf
|
@ -944,7 +944,9 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
|
||||||
pytest.param(5),
|
pytest.param(5),
|
||||||
])
|
])
|
||||||
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
||||||
""" Test that the number of sanity check batches is clipped to limit_val_batches. """
|
"""
|
||||||
|
Test that the number of sanity check batches is clipped to `limit_val_batches`.
|
||||||
|
"""
|
||||||
model = EvalModelTemplate()
|
model = EvalModelTemplate()
|
||||||
model.validation_step = model.validation_step__multiple_dataloaders
|
model.validation_step = model.validation_step__multiple_dataloaders
|
||||||
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
|
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
|
||||||
|
@ -958,6 +960,16 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
||||||
)
|
)
|
||||||
assert trainer.num_sanity_val_steps == num_sanity_val_steps
|
assert trainer.num_sanity_val_steps == num_sanity_val_steps
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
|
||||||
|
) as mocked:
|
||||||
|
val_dataloaders = model.val_dataloader__multiple_mixed_length()
|
||||||
|
trainer.fit(model, val_dataloaders=val_dataloaders)
|
||||||
|
|
||||||
|
assert mocked.call_count == sum(
|
||||||
|
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(['limit_val_batches'], [
|
@pytest.mark.parametrize(['limit_val_batches'], [
|
||||||
pytest.param(0.0), # this should run no sanity checks
|
pytest.param(0.0), # this should run no sanity checks
|
||||||
|
@ -967,8 +979,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
||||||
])
|
])
|
||||||
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
||||||
"""
|
"""
|
||||||
Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
|
Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as
|
||||||
limited by "limit_val_batches" Trainer argument.
|
limited by `limit_val_batches` Trainer argument.
|
||||||
"""
|
"""
|
||||||
model = EvalModelTemplate()
|
model = EvalModelTemplate()
|
||||||
model.validation_step = model.validation_step__multiple_dataloaders
|
model.validation_step = model.validation_step__multiple_dataloaders
|
||||||
|
@ -981,6 +993,14 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
||||||
)
|
)
|
||||||
assert trainer.num_sanity_val_steps == float('inf')
|
assert trainer.num_sanity_val_steps == float('inf')
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
trainer.evaluation_loop, 'evaluation_step', wraps=trainer.evaluation_loop.evaluation_step
|
||||||
|
) as mocked:
|
||||||
|
val_dataloaders = model.val_dataloader__multiple()
|
||||||
|
trainer.fit(model, val_dataloaders=val_dataloaders)
|
||||||
|
|
||||||
|
assert mocked.call_count == sum(trainer.num_val_batches)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("trainer_kwargs,expected", [
|
@pytest.mark.parametrize("trainer_kwargs,expected", [
|
||||||
pytest.param(
|
pytest.param(
|
||||||
|
|
Loading…
Reference in New Issue