diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc89c5bcf0..2a11ec0378 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
 
+- Progress tracking
+    * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+
+
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 213be0de21..894f4e9197 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -39,8 +39,7 @@ class TrainingEpochLoop(loops.Loop):
         self.min_steps: int = min_steps
         self.max_steps: int = max_steps
         self.global_step: int = 0
-        # the total batch index across all epochs
-        self.total_batch_idx: int = 0
+        # manually tracking which is the last batch is necessary for iterable dataset support
         self.is_last_batch: Optional[bool] = None
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()
@@ -53,6 +52,13 @@ class TrainingEpochLoop(loops.Loop):
         self._warning_cache: WarningCache = WarningCache()
         self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
 
+    @property
+    def total_batch_idx(self) -> int:
+        """Returns the current batch index (across epochs)"""
+        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
+        # but before the next `ready` increase
+        return self.batch_progress.total.ready - 1
+
     @property
     def batch_idx(self) -> int:
         """Returns the current batch index (within this epoch)"""
@@ -176,14 +182,9 @@ class TrainingEpochLoop(loops.Loop):
         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers("step", update_plateau_schedulers=True)
 
-        self.total_batch_idx += 1
-
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()
 
-        if self.done:
-            raise StopIteration
-
     def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         """Calls the on_epoch_end hook.
@@ -351,7 +352,7 @@ class TrainingEpochLoop(loops.Loop):
         """Increments global step according to grads progress"""
         if not self._should_accumulate():
             self.global_step = self.trainer.accelerator.update_global_step(
-                self.total_batch_idx, self.trainer.global_step
+                self.batch_progress.current.ready, self.trainer.global_step
             )
 
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index ed26d50d6c..b77f186453 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -63,12 +63,12 @@ class FitLoop(Loop):
 
     @property
     def total_batch_idx(self) -> int:
-        """Returns the total number of batches already run (across all epochs)"""
+        """Returns the current batch index (across epochs)"""
         return self.epoch_loop.total_batch_idx
 
     @property
     def batch_idx(self) -> int:
-        """Returns the number of batches already run within this epoch"""
+        """Returns the current batch index (within this epoch)"""
         return self.epoch_loop.batch_idx
 
     @property
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index a4e46fb619..de1873ee39 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -184,7 +184,7 @@ def test_accumulation_and_early_stopping(tmpdir):
 
     assert lrfinder.suggestion() != 1e-3
     assert len(lrfinder.results["lr"]) == 100
-    assert lrfinder._total_batch_idx == 200
+    assert lrfinder._total_batch_idx == 199
 
 
 def test_suggestion_parameters_work(tmpdir):
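
Note on the counting change in this diff: `total_batch_idx` is now derived from the batch `Progress` counters (`total.ready - 1`) instead of a manually incremented attribute, which shifts it from "number of batches already run" to "index of the current batch". The sketch below is a simplified, self-contained illustration of that off-by-one; the `Tracker`/`BatchProgress` classes here are stand-ins for Lightning's real progress-tracking dataclasses, not the actual API.

```python
from dataclasses import dataclass, field


@dataclass
class Tracker:
    """Simplified stand-in for Lightning's progress tracker (not the real class)."""
    ready: int = 0       # incremented when a batch has been fetched and is about to be processed
    completed: int = 0   # incremented once the batch has finished processing


@dataclass
class BatchProgress:
    total: Tracker = field(default_factory=Tracker)  # never reset, spans all epochs


def total_batch_idx(progress: BatchProgress) -> int:
    # `ready` is used instead of `completed`: between `completed += 1` for batch N
    # and `ready += 1` for batch N + 1, the current batch index is still N, i.e. ready - 1.
    return progress.total.ready - 1


progress = BatchProgress()
for batch in range(200):  # e.g. 100 optimizer steps with accumulate_grad_batches=2
    progress.total.ready += 1
    assert total_batch_idx(progress) == batch  # index of the batch being processed
    progress.total.completed += 1

# The old manually incremented counter would end at 200 (batches run), while the
# index of the last batch is 199 -- hence the updated assertion in test_lr_finder.py.
assert total_batch_idx(progress) == 199
```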