Fix val_progress_bar total with num_sanity_val_steps (#3751)

* Fix val_progress_bar total with num_sanity_val_steps

* chlog

* Fix val_progress_bar total with num_sanity_val_steps

* move test

* replaced with sanity flag and suggestions
This commit is contained in:
Rohit Gupta 2020-10-04 18:02:18 +05:30 committed by GitHub
parent 4da240ea1b
commit a628d181ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 6 deletions

View File

@ -93,9 +93,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335)) - Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630)) - Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735)) - Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751))
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764)) - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))

View File

@ -340,8 +340,9 @@ class ProgressBar(ProgressBarBase):
def on_validation_start(self, trainer, pl_module): def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module) super().on_validation_start(trainer, pl_module)
self.val_progress_bar = self.init_validation_tqdm() if not trainer.running_sanity_check:
self.val_progress_bar.total = convert_inf(self.total_val_batches) self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx) super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)

View File

@ -193,3 +193,37 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
trainer.test(model) trainer.test(model)
assert progress_bar.test_batches_seen == progress_bar.total_test_batches assert progress_bar.test_batches_seen == progress_bar.total_test_batches
@pytest.mark.parametrize(['limit_val_batches', 'expected'], [
pytest.param(0, 0),
pytest.param(5, 7),
])
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected):
"""
Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
"""
class CurrentProgressBar(ProgressBar):
def __init__(self):
super().__init__()
self.val_progress_bar_total = 0
def on_validation_epoch_end(self, trainer, pl_module):
self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total
model = EvalModelTemplate()
progress_bar = CurrentProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
num_sanity_val_steps=2,
limit_train_batches=0,
limit_val_batches=limit_val_batches,
callbacks=[progress_bar],
logger=False,
checkpoint_callback=False,
early_stop_callback=False,
)
trainer.fit(model)
assert trainer.progress_bar_callback.val_progress_bar_total == expected

View File

@ -957,7 +957,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
max_steps=1, max_steps=1,
) )
assert trainer.num_sanity_val_steps == num_sanity_val_steps assert trainer.num_sanity_val_steps == num_sanity_val_steps
val_dataloaders = model.val_dataloader__multiple_mixed_length()
@pytest.mark.parametrize(['limit_val_batches'], [ @pytest.mark.parametrize(['limit_val_batches'], [
@ -981,7 +980,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
max_steps=1, max_steps=1,
) )
assert trainer.num_sanity_val_steps == float('inf') assert trainer.num_sanity_val_steps == float('inf')
val_dataloaders = model.val_dataloader__multiple()
@pytest.mark.parametrize("trainer_kwargs,expected", [ @pytest.mark.parametrize("trainer_kwargs,expected", [