Integrate `total_batch_idx` with progress tracking (#8598)
parent bfeffde8f4 · commit 0aa5cc7b77

@@ -12,6 +12,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

+- Progress tracking
+
+  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

@@ -39,8 +39,7 @@ class TrainingEpochLoop(loops.Loop):
         self.min_steps: int = min_steps
         self.max_steps: int = max_steps
         self.global_step: int = 0
-        # the total batch index across all epochs
-        self.total_batch_idx: int = 0
         # manually tracking which is the last batch is necessary for iterable dataset support
         self.is_last_batch: Optional[bool] = None
         self.batch_progress = Progress()
         self.scheduler_progress = SchedulerProgress()

@@ -53,6 +52,13 @@ class TrainingEpochLoop(loops.Loop):
         self._warning_cache: WarningCache = WarningCache()
         self._epoch_output: Optional[List[List[STEP_OUTPUT]]] = None

+    @property
+    def total_batch_idx(self) -> int:
+        """Returns the current batch index (across epochs)"""
+        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
+        # but before the next `ready` increase
+        return self.batch_progress.total.ready - 1
+
     @property
     def batch_idx(self) -> int:
         """Returns the current batch index (within this epoch)"""
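The off-by-one in the new property is easier to see with a toy version of the progress counters. Below is a minimal sketch, assuming a simplified stand-in for Lightning's `Progress` tracking; only the `ready` and `completed` counter names are taken from the diff above, everything else is illustrative:

```python
from dataclasses import dataclass, field


@dataclass
class _Tracker:
    # `ready` is bumped when a batch is fetched; `completed` once the
    # batch has been fully processed
    ready: int = 0
    completed: int = 0


@dataclass
class _Progress:
    total: _Tracker = field(default_factory=_Tracker)

    def increment_ready(self) -> None:
        self.total.ready += 1

    def increment_completed(self) -> None:
        self.total.completed += 1


progress = _Progress()
for _ in range(3):  # simulate three batches
    progress.increment_ready()
    # while the batch is in flight, the zero-based index of the current
    # batch is `ready - 1`
    assert progress.total.ready - 1 == progress.total.completed
    progress.increment_completed()

# after `completed` has been increased but before the next `ready` increase,
# `ready - 1` still points at the batch that just finished, which is why the
# property reads `ready` rather than `completed`
assert progress.total.ready - 1 == 2
```

This also explains why the manual `self.total_batch_idx += 1` bookkeeping removed in the next hunk is no longer needed: the loop already increments `batch_progress`, so the index is derived from it instead of being tracked twice.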

@@ -176,14 +182,9 @@ class TrainingEpochLoop(loops.Loop):
         # update plateau LR scheduler after metrics are logged
         self.update_lr_schedulers("step", update_plateau_schedulers=True)

-        self.total_batch_idx += 1
-
         # progress global step according to grads progress
         self._increment_accumulated_grad_global_step()

         if self.done:
             raise StopIteration

     def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         """Calls the on_epoch_end hook.

@@ -351,7 +352,7 @@ class TrainingEpochLoop(loops.Loop):
         """Increments global step according to grads progress"""
         if not self._should_accumulate():
             self.global_step = self.trainer.accelerator.update_global_step(
-                self.total_batch_idx, self.trainer.global_step
+                self.batch_progress.current.ready, self.trainer.global_step
             )

     def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
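The only change above is which counter feeds the accelerator: the removed attribute is replaced by `self.batch_progress.current.ready`. The surrounding accumulation gating is untouched and behaves as in the following standalone sketch; the two-argument `update_global_step` signature is taken from the call above, but its plain `+ 1` body and the window arithmetic here are assumptions for illustration:

```python
ACCUMULATE_GRAD_BATCHES = 4  # illustrative accumulation window


def should_accumulate(batch_idx: int) -> bool:
    # keep accumulating until a full window of batches has been processed
    return (batch_idx + 1) % ACCUMULATE_GRAD_BATCHES != 0


def update_global_step(total_batch_idx: int, current_global_step: int) -> int:
    # assumed default behaviour: one optimizer step -> one global step
    return current_global_step + 1


global_step = 0
for batch_idx in range(8):
    if not should_accumulate(batch_idx):
        global_step = update_global_step(batch_idx, global_step)

# 8 batches with a window of 4 -> 2 optimizer steps -> global_step == 2
assert global_step == 2
```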

@@ -63,12 +63,12 @@ class FitLoop(Loop):

     @property
     def total_batch_idx(self) -> int:
-        """Returns the total number of batches already run (across all epochs)"""
+        """Returns the current batch index (across epochs)"""
         return self.epoch_loop.total_batch_idx

     @property
     def batch_idx(self) -> int:
-        """Returns the number of batches already run within this epoch"""
+        """Returns the current batch index (within this epoch)"""
         return self.epoch_loop.batch_idx

     @property
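`FitLoop` keeps its public properties but now delegates both counters to the epoch loop, so they stay consistent with the progress dataclasses. A hypothetical callback sketch showing how user code can read them; the hook signature follows the 1.4-era `Callback` API, and the callback itself is purely illustrative:

```python
import pytorch_lightning as pl


class BatchIndexLogger(pl.Callback):
    """Illustrative callback printing both batch counters."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # `batch_idx` resets at every epoch boundary;
        # `total_batch_idx` keeps counting across epochs
        print(
            f"epoch-local index: {trainer.fit_loop.batch_idx}, "
            f"across-epochs index: {trainer.fit_loop.total_batch_idx}"
        )
```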

@@ -184,7 +184,7 @@ def test_accumulation_and_early_stopping(tmpdir):

     assert lrfinder.suggestion() != 1e-3
     assert len(lrfinder.results["lr"]) == 100
-    assert lrfinder._total_batch_idx == 200
+    assert lrfinder._total_batch_idx == 199


 def test_suggestion_parameters_work(tmpdir):
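The updated expectation is the same off-by-one made visible in a test: after the learning-rate finder has consumed 200 batches, `total.ready` is 200, so the zero-based index is 199. In short:

```python
batches_run = 200              # batches the LR finder consumes in this test
total_ready = batches_run      # `ready` is bumped once per batch
assert total_ready - 1 == 199  # `total_batch_idx` is zero-based
```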