Run main progress bar independent of val progress bar in `TQDMProgressBar` (#12563)
Co-authored-by: carmocca <carlossmocholi@gmail.com>
This commit is contained in:
parent
cf0e3c6250
commit
f4883d6ead
|
@ -102,6 +102,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Run main progress bar updates independent of val progress bar updates in `TQDMProgressBar` ([#12563](https://github.com/PyTorchLightning/pytorch-lightning/pull/12563))
|
||||
|
||||
|
||||
- Avoid calling `average_parameters` multiple times per optimizer step ([#12452](https://github.com/PyTorchLightning/pytorch-lightning/pull/12452))
|
||||
|
||||
|
||||
|
|
|
@ -263,8 +263,9 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
|
||||
|
||||
def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
|
||||
if self._should_update(self.train_batch_idx, self.total_train_batches):
|
||||
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
|
||||
current = self.train_batch_idx + self._val_processed
|
||||
if self._should_update(current, self.main_progress_bar.total):
|
||||
_update_n(self.main_progress_bar, current)
|
||||
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
|
@ -289,10 +290,12 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
|
||||
if self._should_update(self.val_batch_idx, self.total_val_batches_current_dataloader):
|
||||
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
|
||||
_update_n(self.val_progress_bar, self.val_batch_idx)
|
||||
if trainer.state.fn == "fit":
|
||||
_update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed)
|
||||
|
||||
current = self.train_batch_idx + self._val_processed
|
||||
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
|
||||
_update_n(self.main_progress_bar, current)
|
||||
|
||||
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if self._main_progress_bar is not None and trainer.state.fn == "fit":
|
||||
|
@ -313,7 +316,7 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_test_batch_end(self, *_: Any) -> None:
|
||||
if self._should_update(self.test_batch_idx, self.total_test_batches_current_dataloader):
|
||||
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
|
||||
_update_n(self.test_progress_bar, self.test_batch_idx)
|
||||
|
||||
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
|
@ -333,7 +336,7 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_predict_batch_end(self, *_: Any) -> None:
|
||||
if self._should_update(self.predict_batch_idx, self.total_predict_batches_current_dataloader):
|
||||
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
|
||||
_update_n(self.predict_progress_bar, self.predict_batch_idx)
|
||||
|
||||
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
|
@ -356,8 +359,8 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
s = sep.join(map(str, args))
|
||||
active_progress_bar.write(s, **kwargs)
|
||||
|
||||
def _should_update(self, current: int, total: Union[int, float]) -> bool:
|
||||
return self.refresh_rate > 0 and (current % self.refresh_rate == 0 or current == total)
|
||||
def _should_update(self, current: int, total: int) -> bool:
|
||||
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_refresh_rate(refresh_rate: int) -> int:
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
|
@ -347,10 +348,10 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
|
|||
[2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
|
||||
[0, 0, 3, None, None],
|
||||
[1, 0, 3, [1], None],
|
||||
[1, 1, 3, [1, 2], [1]],
|
||||
[1, 1, 3, [2], [1]],
|
||||
[5, 0, 3, [3, 5], None],
|
||||
[5, 2, 3, [3, 5, 7], [2]],
|
||||
[5, 2, 6, [5, 7], [2]],
|
||||
[5, 2, 3, [3, 6, 7], [2]],
|
||||
[5, 2, 6, [6, 7], [2]],
|
||||
],
|
||||
)
|
||||
def test_main_progress_bar_update_amount(
|
||||
|
@ -549,16 +550,56 @@ def test_tqdm_progress_bar_can_be_pickled():
|
|||
pickle.dumps(bar)
|
||||
|
||||
|
||||
@RunIf(min_gpus=2, standalone=True)
|
||||
@pytest.mark.parametrize(
|
||||
["total_train_samples", "train_batch_size", "total_val_samples", "val_batch_size", "val_check_interval"],
|
||||
[(8, 4, 2, 1, 0.2), (8, 4, 2, 1, 0.5)],
|
||||
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
|
||||
[(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
|
||||
)
|
||||
def test_progress_bar_max_val_check_interval(
|
||||
tmpdir, total_train_samples, train_batch_size, total_val_samples, val_batch_size, val_check_interval
|
||||
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
|
||||
):
|
||||
limit_batches = 7
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
num_sanity_val_steps=0,
|
||||
max_epochs=1,
|
||||
enable_model_summary=False,
|
||||
val_check_interval=val_check_interval,
|
||||
limit_train_batches=limit_batches,
|
||||
limit_val_batches=limit_batches,
|
||||
callbacks=TQDMProgressBar(refresh_rate=3),
|
||||
)
|
||||
with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
|
||||
trainer.fit(model)
|
||||
|
||||
pbar = trainer.progress_bar_callback
|
||||
assert pbar.main_progress_bar.n_values == main_progress_bar_updates
|
||||
assert pbar.val_progress_bar.n_values == val_progress_bar_updates
|
||||
|
||||
val_check_batch = (
|
||||
max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
|
||||
)
|
||||
assert trainer.val_check_batch == val_check_batch
|
||||
val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
|
||||
pbar_callback = trainer.progress_bar_callback
|
||||
total_val_batches = limit_batches * val_checks_per_epoch
|
||||
|
||||
assert pbar_callback.val_progress_bar.n == limit_batches
|
||||
assert pbar_callback.val_progress_bar.total == limit_batches
|
||||
assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
|
||||
assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
|
||||
assert pbar_callback.is_enabled
|
||||
|
||||
|
||||
@RunIf(min_gpus=2, standalone=True)
|
||||
@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
|
||||
def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
|
||||
world_size = 2
|
||||
train_data = DataLoader(RandomDataset(32, total_train_samples), batch_size=train_batch_size)
|
||||
total_train_samples = 16
|
||||
train_batch_size = 4
|
||||
total_val_samples = 2
|
||||
val_batch_size = 1
|
||||
train_data = DataLoader(RandomDataset(32, 8), batch_size=train_batch_size)
|
||||
val_data = DataLoader(RandomDataset(32, total_val_samples), batch_size=val_batch_size)
|
||||
|
||||
model = BoringModel()
|
||||
|
@ -585,8 +626,8 @@ def test_progress_bar_max_val_check_interval(
|
|||
assert pbar_callback.val_progress_bar.n == total_val_batches
|
||||
assert pbar_callback.val_progress_bar.total == total_val_batches
|
||||
total_val_batches = total_val_batches * val_checks_per_epoch
|
||||
assert pbar_callback.main_progress_bar.n == total_train_batches + total_val_batches
|
||||
assert pbar_callback.main_progress_bar.total == total_train_batches + total_val_batches
|
||||
assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
|
||||
assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
|
||||
assert pbar_callback.is_enabled
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue