Fix num_sanity_val_steps is clipped to limit_val_batches (#2917)

* Fix num_sanity_val_steps according to limit_val_steps

* fix test

* add num_sanity_batches

* pep

* update docstring in test

* add more test

* chlog

* update comments and docstring in test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Adrian Wälchli <adrian.waelchli@inf.unibe.ch>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
This commit is contained in:
Rohit Gupta 2020-08-21 23:41:31 +05:30 committed by GitHub
parent bcdb750976
commit 7cca3859a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 13 deletions

View File

@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))
## [0.9.0] - YYYY-MM-DD
@ -121,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))
## [0.8.5] - 2020-07-09
### Added

View File

@ -307,7 +307,7 @@ class ProgressBar(ProgressBarBase):
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
self.main_progress_bar = tqdm(disable=True) # dummy progress bar
def on_sanity_check_end(self, trainer, pl_module):

View File

@ -377,6 +377,7 @@ class Trainer(
self.logged_metrics = {}
self.num_training_batches = 0
self.num_val_batches = []
self.num_sanity_val_batches = []
self.num_test_batches = []
self.train_dataloader = None
self.test_dataloaders = None
@ -463,9 +464,9 @@ class Trainer(
self.min_steps = min_steps
if num_sanity_val_steps == -1:
self.num_sanity_val_steps = float("inf")
self.num_sanity_val_steps = float('inf')
else:
self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
self.num_sanity_val_steps = num_sanity_val_steps
self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
@ -1239,7 +1240,6 @@ class Trainer(
self.train()
def _run_sanity_check(self, ref_model, model):
using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
@ -1247,14 +1247,15 @@ class Trainer(
# to make sure program won't crash during val
if should_sanity_check:
self.reset_val_dataloader(ref_model)
self.num_sanity_val_batches = [
min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
]
# hook and callback
self.running_sanity_check = True
self.on_sanity_check_start()
num_loaders = len(self.val_dataloaders)
max_batches = [self.num_sanity_val_steps] * num_loaders
eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)
# allow no returns from eval
if eval_results is not None and len(eval_results) > 0:

View File

@ -907,12 +907,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
pytest.param(0.0), # this should run no sanity checks
pytest.param(1),
pytest.param(1.0),
pytest.param(0.3),
pytest.param(0.5),
pytest.param(5),
])
def test_num_sanity_val_steps(tmpdir, limit_val_batches):
""" Test that the number of sanity check batches is clipped to limit_val_batches. """
model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
num_sanity_val_steps = 4
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=num_sanity_val_steps,
limit_val_batches=limit_val_batches,
max_steps=1,
)
assert trainer.num_sanity_val_steps == num_sanity_val_steps
val_dataloaders = model.val_dataloader__multiple_mixed_length()
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(
min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
)
@pytest.mark.parametrize(['limit_val_batches'], [
pytest.param(0.0), # this should run no sanity checks
pytest.param(1),
pytest.param(1.0),
pytest.param(0.3),
])
def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
"""
Test that num_sanity_val_steps=-1 runs through all validation data once.
Makes sure this setting is independent of limit_val_batches.
Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
limited by "limit_val_batches" Trainer argument.
"""
model = EvalModelTemplate()
model.validation_step = model.validation_step__multiple_dataloaders
@ -920,7 +950,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=-1,
limit_val_batches=limit_val_batches, # should have no influence
limit_val_batches=limit_val_batches,
max_steps=1,
)
assert trainer.num_sanity_val_steps == float('inf')
@ -928,7 +958,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
trainer.fit(model, val_dataloaders=val_dataloaders)
assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
assert mocked.call_count == sum(trainer.num_val_batches)
@pytest.mark.parametrize("trainer_kwargs,expected", [