Fix val_progress_bar total with num_sanity_val_steps (#3751)
* Fix val_progress_bar total with num_sanity_val_steps * chlog * Fix val_progress_bar total with num_sanity_val_steps * move test * replaced with sanity flag and suggestions
This commit is contained in:
parent
4da240ea1b
commit
a628d181ee
|
@ -93,9 +93,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
|
- Fixed determinism in `DDPSpawnBackend` when using `seed_everything` in main process ([#3335](https://github.com/PyTorchLightning/pytorch-lightning/pull/3335))
|
||||||
|
|
||||||
- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
|
- Fixed `ModelCheckpoint` `period` to actually save every `period` epochs ([#3630](https://github.com/PyTorchLightning/pytorch-lightning/pull/3630))
|
||||||
|
|
||||||
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
|
- Fixed `val_progress_bar` total with `num_sanity_val_steps` ([#3751](https://github.com/PyTorchLightning/pytorch-lightning/pull/3751))
|
||||||
|
|
||||||
|
- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))
|
||||||
|
|
||||||
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
|
- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
|
||||||
|
|
||||||
|
|
|
@ -340,8 +340,9 @@ class ProgressBar(ProgressBarBase):
|
||||||
|
|
||||||
def on_validation_start(self, trainer, pl_module):
|
def on_validation_start(self, trainer, pl_module):
|
||||||
super().on_validation_start(trainer, pl_module)
|
super().on_validation_start(trainer, pl_module)
|
||||||
self.val_progress_bar = self.init_validation_tqdm()
|
if not trainer.running_sanity_check:
|
||||||
self.val_progress_bar.total = convert_inf(self.total_val_batches)
|
self.val_progress_bar = self.init_validation_tqdm()
|
||||||
|
self.val_progress_bar.total = convert_inf(self.total_val_batches)
|
||||||
|
|
||||||
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
def on_validation_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
|
||||||
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
|
super().on_validation_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)
|
||||||
|
|
|
@ -193,3 +193,37 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate):
|
||||||
|
|
||||||
trainer.test(model)
|
trainer.test(model)
|
||||||
assert progress_bar.test_batches_seen == progress_bar.total_test_batches
|
assert progress_bar.test_batches_seen == progress_bar.total_test_batches
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(['limit_val_batches', 'expected'], [
|
||||||
|
pytest.param(0, 0),
|
||||||
|
pytest.param(5, 7),
|
||||||
|
])
|
||||||
|
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches, expected):
|
||||||
|
"""
|
||||||
|
Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.
|
||||||
|
"""
|
||||||
|
class CurrentProgressBar(ProgressBar):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.val_progress_bar_total = 0
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self, trainer, pl_module):
|
||||||
|
self.val_progress_bar_total += trainer.progress_bar_callback.val_progress_bar.total
|
||||||
|
|
||||||
|
model = EvalModelTemplate()
|
||||||
|
progress_bar = CurrentProgressBar()
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=1,
|
||||||
|
num_sanity_val_steps=2,
|
||||||
|
limit_train_batches=0,
|
||||||
|
limit_val_batches=limit_val_batches,
|
||||||
|
callbacks=[progress_bar],
|
||||||
|
logger=False,
|
||||||
|
checkpoint_callback=False,
|
||||||
|
early_stop_callback=False,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
assert trainer.progress_bar_callback.val_progress_bar_total == expected
|
||||||
|
|
|
@ -957,7 +957,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
|
||||||
max_steps=1,
|
max_steps=1,
|
||||||
)
|
)
|
||||||
assert trainer.num_sanity_val_steps == num_sanity_val_steps
|
assert trainer.num_sanity_val_steps == num_sanity_val_steps
|
||||||
val_dataloaders = model.val_dataloader__multiple_mixed_length()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(['limit_val_batches'], [
|
@pytest.mark.parametrize(['limit_val_batches'], [
|
||||||
|
@ -981,7 +980,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
|
||||||
max_steps=1,
|
max_steps=1,
|
||||||
)
|
)
|
||||||
assert trainer.num_sanity_val_steps == float('inf')
|
assert trainer.num_sanity_val_steps == float('inf')
|
||||||
val_dataloaders = model.val_dataloader__multiple()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("trainer_kwargs,expected", [
|
@pytest.mark.parametrize("trainer_kwargs,expected", [
|
||||||
|
|
Loading…
Reference in New Issue