From 7cca3859a7b97a9ab4a6c6fb5f36ff94bff7f218 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Fri, 21 Aug 2020 23:41:31 +0530 Subject: [PATCH] Fix num_sanity_val_steps is clipped to limit_val_batches (#2917) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix num_sanity_val_steps according to limit_val_steps * fix test * add num_sanity_batches * pep * update docstring in test * add more test * chlog * update comments and docstring in test Co-authored-by: Adrian Wälchli Co-authored-by: Adrian Wälchli Co-authored-by: Ananya Harsh Jha --- CHANGELOG.md | 2 +- pytorch_lightning/callbacks/progress.py | 2 +- pytorch_lightning/trainer/trainer.py | 13 ++++---- tests/trainer/test_trainer.py | 40 +++++++++++++++++++++---- 4 files changed, 44 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16afd42b3e..fe4ea78c05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `num_sanity_val_steps` so that the number of sanity check batches is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917)) ## [0.9.0] - YYYY-MM-DD @@ -121,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045)) - Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042)) - ## [0.8.5] - 2020-07-09 ### Added diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 16cdace5ce..8b9ab5cd56 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -307,7 +307,7 @@ class ProgressBar(ProgressBarBase): def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders)) + self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9d462ef8f3..d372874f16 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -377,6 +377,7 @@ class Trainer( self.logged_metrics = {} self.num_training_batches = 0 self.num_val_batches = [] + self.num_sanity_val_batches = [] self.num_test_batches = [] self.train_dataloader = None self.test_dataloaders = None @@ -463,9 +464,9 @@ class Trainer( self.min_steps = min_steps if num_sanity_val_steps == -1: - self.num_sanity_val_steps = float("inf") + self.num_sanity_val_steps = float('inf') else: - self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches) + self.num_sanity_val_steps = num_sanity_val_steps self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch @@ -1239,7 +1240,6 @@ class Trainer( self.train() def _run_sanity_check(self, ref_model, model): - using_val_step = ref_model.val_dataloader is not None 
and self.is_overridden('validation_step') should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0 @@ -1247,14 +1247,15 @@ class Trainer( # to make sure program won't crash during val if should_sanity_check: self.reset_val_dataloader(ref_model) + self.num_sanity_val_batches = [ + min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches + ] # hook and callback self.running_sanity_check = True self.on_sanity_check_start() - num_loaders = len(self.val_dataloaders) - max_batches = [self.num_sanity_val_steps] * num_loaders - eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False) + eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False) # allow no returns from eval if eval_results is not None and len(eval_results) > 0: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b27bd97bd2..6fdeb270d9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -907,12 +907,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): pytest.param(0.0), # this should run no sanity checks pytest.param(1), pytest.param(1.0), - pytest.param(0.3), + pytest.param(0.5), + pytest.param(5), ]) def test_num_sanity_val_steps(tmpdir, limit_val_batches): + """ Test that the number of sanity check batches is clipped to limit_val_batches. 
""" + model = EvalModelTemplate() + model.validation_step = model.validation_step__multiple_dataloaders + model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders + num_sanity_val_steps = 4 + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=num_sanity_val_steps, + limit_val_batches=limit_val_batches, + max_steps=1, + ) + assert trainer.num_sanity_val_steps == num_sanity_val_steps + val_dataloaders = model.val_dataloader__multiple_mixed_length() + + with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked: + trainer.fit(model, val_dataloaders=val_dataloaders) + assert mocked.call_count == sum( + min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches + ) + + +@pytest.mark.parametrize(['limit_val_batches'], [ + pytest.param(0.0), # this should run no sanity checks + pytest.param(1), + pytest.param(1.0), + pytest.param(0.3), +]) +def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): """ - Test that num_sanity_val_steps=-1 runs through all validation data once. - Makes sure this setting is independent of limit_val_batches. + Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as + limited by "limit_val_batches" Trainer argument. 
""" model = EvalModelTemplate() model.validation_step = model.validation_step__multiple_dataloaders @@ -920,7 +950,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=-1, - limit_val_batches=limit_val_batches, # should have no influence + limit_val_batches=limit_val_batches, max_steps=1, ) assert trainer.num_sanity_val_steps == float('inf') @@ -928,7 +958,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked: trainer.fit(model, val_dataloaders=val_dataloaders) - assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders) + assert mocked.call_count == sum(trainer.num_val_batches) @pytest.mark.parametrize("trainer_kwargs,expected", [