Integrate `total_batch_idx` with progress tracking (#8598)

Carlos Mocholí 2021-08-14 14:08:34 +02:00 committed by GitHub
parent bfeffde8f4
commit 0aa5cc7b77
4 changed files with 16 additions and 11 deletions

CHANGELOG.md

@@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
+- Progress tracking
+  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

pytorch_lightning/loops/epoch/training_epoch_loop.py

@@ -39,8 +39,7 @@ class TrainingEpochLoop(loops.Loop):
         self.min_steps: int = min_steps
         self.max_steps: int = max_steps
         self.global_step: int = 0
-        # the total batch index across all epochs
-        self.total_batch_idx: int = 0
         # manually tracking which is the last batch is necessary for iterable dataset support
         self.is_last_batch: Optional[bool] = None
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()
@@ -53,6 +52,13 @@ class TrainingEpochLoop(loops.Loop):
         self._warning_cache: WarningCache = WarningCache()
         self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None
 
+    @property
+    def total_batch_idx(self) -> int:
+        """Returns the current batch index (across epochs)"""
+        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
+        # but before the next `ready` increase
+        return self.batch_progress.total.ready - 1
+
     @property
     def batch_idx(self) -> int:
         """Returns the current batch index (within this epoch)"""
@@ -176,14 +182,9 @@ class TrainingEpochLoop(loops.Loop):
         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers("step", update_plateau_schedulers=True)
 
-        self.total_batch_idx += 1
-
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()
 
         if self.done:
             raise StopIteration
 
     def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         """Calls the on_epoch_end hook.
@@ -351,7 +352,7 @@ class TrainingEpochLoop(loops.Loop):
         """Increments global step according to grads progress"""
         if not self._should_accumulate():
             self.global_step = self.trainer.accelerator.update_global_step(
-                self.total_batch_idx, self.trainer.global_step
+                self.batch_progress.current.ready, self.trainer.global_step
            )
 
     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
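Note that this call site now passes `batch_progress.current.ready` rather than the cross-epoch total. Presumably this relies on `current` counters resetting at each epoch start while `total` counters span the whole run; a simplified model of that assumed behaviour (names illustrative):

```python
# Simplified model of the assumed counter scopes: `total` spans the whole
# run, `current` is reset at each epoch start.
class BatchCounters:
    def __init__(self) -> None:
        self.total_ready = 0
        self.current_ready = 0

    def on_epoch_start(self) -> None:
        self.current_ready = 0

    def on_batch_fetched(self) -> None:
        self.total_ready += 1
        self.current_ready += 1


c = BatchCounters()
for _epoch in range(2):
    c.on_epoch_start()
    for _batch in range(3):
        c.on_batch_fetched()

assert c.total_ready == 6    # across both epochs
assert c.current_ready == 3  # within the latest epoch only
```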

pytorch_lightning/loops/fit_loop.py

@@ -63,12 +63,12 @@ class FitLoop(Loop):
     @property
     def total_batch_idx(self) -> int:
-        """Returns the total number of batches already run (across all epochs)"""
+        """Returns the current batch index (across epochs)"""
         return self.epoch_loop.total_batch_idx
 
     @property
     def batch_idx(self) -> int:
-        """Returns the number of batches already run within this epoch"""
+        """Returns the current batch index (within this epoch)"""
         return self.epoch_loop.batch_idx
 
     @property

tests/tuner/test_lr_finder.py

@@ -184,7 +184,7 @@ def test_accumulation_and_early_stopping(tmpdir):
     assert lrfinder.suggestion() != 1e-3
     assert len(lrfinder.results["lr"]) == 100
-    assert lrfinder._total_batch_idx == 200
+    assert lrfinder._total_batch_idx == 199
 
 
 def test_suggestion_parameters_work(tmpdir):
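The expected value changes from 200 to 199 because the attribute's meaning changed from a batch *count* to a 0-based batch *index*. A worked example of the arithmetic (the 200-batch figure is assumed from the test's setup, e.g. 100 recorded learning rates with 2-batch gradient accumulation):

```python
# Worked numbers for the updated assertion.
n_batches = 200

# Old attribute: incremented once after every batch, so it ends at the
# *count* of batches run.
old_total_batch_idx = 0
for _ in range(n_batches):
    old_total_batch_idx += 1
assert old_total_batch_idx == 200

# New property: `batch_progress.total.ready - 1`, i.e. the 0-based *index*
# of the last batch fetched.
new_total_batch_idx = n_batches - 1
assert new_total_batch_idx == 199
```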