diff --git a/CHANGELOG.md b/CHANGELOG.md
index 08d6458e4a..0ddbff5e9f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -93,9 +93,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
 
-- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
+- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
 
-- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
+- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751))
+
+- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
 
 - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index 9bffc9883a..3db81fe322 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -340,8 +340,9 @@ class ProgressBar(ProgressBarBase):
 
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
-        self.val_progress_bar = self.init_validation_tqdm()
-        self.val_progress_bar.total = convert_inf(self.total_val_batches)
+        if not trainer.running_sanity_check:
+            self.val_progress_bar = self.init_validation_tqdm()
+            self.val_progress_bar.total = convert_inf(self.total_val_batches)
 
     def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 713bdf3c3c..91eecdcf37 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -193,3 +193,37 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
 
     trainer.test(model)
     assert progress_bar.test_batches_seen == progress_bar.total_test_batches
+
+
+@pytest.mark.parametrize(['limit_val_batches', 'expected'], [
+    pytest.param(0, 0),
+    pytest.param(5, 7),
+])
+def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected):
+    """
+    Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
+    """
+    class CurrentProgressBar(ProgressBar):
+        def __init__(self):
+            super().__init__()
+            self.val_progress_bar_total = 0
+
+        def on_validation_epoch_end(self, trainer, pl_module):
+            self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total
+
+    model = EvalModelTemplate()
+    progress_bar = CurrentProgressBar()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        num_sanity_val_steps=2,
+        limit_train_batches=0,
+        limit_val_batches=limit_val_batches,
+        callbacks=[progress_bar],
+        logger=False,
+        checkpoint_callback=False,
+        early_stop_callback=False,
+    )
+    trainer.fit(model)
+    assert trainer.progress_bar_callback.val_progress_bar_total == expected
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index e0049e7aad..78711b637e 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -957,7 +957,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
         max_steps=1,
     )
     assert trainer.num_sanity_val_steps == num_sanity_val_steps
-    val_dataloaders = model.val_dataloader__multiple_mixed_length()
 
 
 @pytest.mark.parametrize(['limit_val_batches'], [
@@ -981,7 +980,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
         max_steps=1,
     )
     assert trainer.num_sanity_val_steps == float('inf')
-    val_dataloaders = model.val_dataloader__multiple()
 
 
 @pytest.mark.parametrize("trainer_kwargs,expected", [