Fix rich with uneven refresh rate tracking (#11668)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Rohit Gupta 2022-02-03 15:57:05 +05:30 committed by GitHub
parent 7948ed703d
commit e9065e9d42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 105 additions and 37 deletions

View File

@ -509,12 +509,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))
- Fixed `RichProgressBar` progress when refresh rate does not evenly divide the total counter ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
- Fixed `RichProgressBar` progress validation bar total when using multiple validation runs within a single training epoch ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
- The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
- Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))
- The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
- Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))

View File

@ -105,6 +105,9 @@ class ProgressBarBase(Callback):
Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
dataloader is of infinite size.
"""
if self.trainer.sanity_checking:
return sum(self.trainer.num_sanity_val_batches)
total_val_batches = 0
if self.trainer.enable_validation:
is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0

View File

@ -318,8 +318,6 @@ class RichProgressBar(ProgressBarBase):
def on_sanity_check_start(self, trainer, pl_module):
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()
def on_sanity_check_end(self, trainer, pl_module):
if self.progress is not None:
@ -349,14 +347,13 @@ class RichProgressBar(ProgressBarBase):
self.refresh()
def on_validation_epoch_start(self, trainer, pl_module):
if self.total_val_batches > 0:
total_val_batches = self.total_val_batches
if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
# val can be checked multiple times per epoch
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()
if trainer.sanity_checking:
self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
else:
self.val_progress_bar_id = self._add_task(
self.total_val_batches, self.validation_description, visible=False
)
self.refresh()
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
@ -364,17 +361,20 @@ class RichProgressBar(ProgressBarBase):
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
leftover = current % self.refresh_rate
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
self.progress.update(progress_bar_id, advance=advance, visible=visible)
self.refresh()
def _should_update(self, current: int, total: int) -> bool:
def _should_update(self, current: int, total: Union[int, float]) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
def on_validation_epoch_end(self, trainer, pl_module):
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
self.refresh()
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.state.fn == "fit":

View File

@ -178,9 +178,6 @@ class EvaluationLoop(DataLoaderLoop):
max_batches = self.trainer.num_test_batches
else:
if self.trainer.sanity_checking:
self.trainer.num_sanity_val_batches = [
min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
]
max_batches = self.trainer.num_sanity_val_batches
else:
max_batches = self.trainer.num_val_batches

View File

@ -1344,6 +1344,9 @@ class Trainer(
# reload dataloaders
val_loop._reload_evaluation_dataloaders()
self.num_sanity_val_batches = [
min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
]
# run eval step
with torch.no_grad():

View File

@ -46,9 +46,8 @@ def test_rich_progress_bar_refresh_rate_enabled():
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
def test_rich_progress_bar(progress_update, tmpdir, dataset):
def test_rich_progress_bar(tmpdir, dataset):
class TestModel(BoringModel):
def train_dataloader(self):
return DataLoader(dataset=dataset)
@ -62,8 +61,6 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
def predict_dataloader(self):
return DataLoader(dataset=dataset)
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
@ -71,16 +68,27 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
max_steps=1,
max_epochs=1,
callbacks=RichProgressBar(),
)
model = TestModel()
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
trainer.predict(model)
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.fit(model)
# 3 for main progress bar and 1 for val progress bar
assert mocked.call_count == 4
assert progress_update.call_count == 8
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.validate(model)
assert mocked.call_count == 1
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.test(model)
assert mocked.call_count == 1
with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
trainer.predict(model)
assert mocked.call_count == 1
def test_rich_progress_bar_import_error(monkeypatch):
@ -186,11 +194,20 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=4,
callbacks=RichProgressBar(refresh_rate=0),
)
trainer.fit(BoringModel())
assert progress_update.call_count == 0
@RunIf(rich=True)
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
@ -200,14 +217,26 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, e
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)
trainer.fit(model)
trainer.progress_bar_callback.on_train_start(trainer, model)
with mock.patch.object(
trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
) as progress_update:
trainer.fit(model)
assert progress_update.call_count == expected_call_count
assert progress_update.call_count == expected_call_count
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_main_bar.completed == 12
assert fit_main_bar.total == 12
assert fit_main_bar.visible
assert fit_val_bar.completed == 6
assert fit_val_bar.total == 6
assert not fit_val_bar.visible
@RunIf(rich=True)
@pytest.mark.parametrize("limit_val_batches", (1, 5))
def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches):
model = BoringModel()
progress_bar = RichProgressBar()
@ -224,6 +253,36 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
trainer.fit(model)
assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
assert progress_bar.progress.tasks[0].total == min(num_sanity_val_steps, limit_val_batches)
@RunIf(rich=True)
def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
"""Test the completed and total counter for rich progress bar when using val_check_interval."""
progress_bar = RichProgressBar()
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
val_check_interval=2,
max_epochs=1,
limit_train_batches=7,
limit_val_batches=4,
callbacks=[progress_bar],
)
trainer.fit(model)
fit_main_progress_bar = progress_bar.progress.tasks[1]
assert fit_main_progress_bar.completed == 7 + 3 * 4
assert fit_main_progress_bar.total == 7 + 3 * 4
fit_val_bar = progress_bar.progress.tasks[2]
assert fit_val_bar.completed == 4
assert fit_val_bar.total == 4
trainer.validate(model)
val_bar = progress_bar.progress.tasks[0]
assert val_bar.completed == 4
assert val_bar.total == 4
@RunIf(rich=True)