Rename `reset_on_epoch` to `reset_on_run` (#9658)
This commit is contained in:
parent
8fcdcb598b
commit
d02fc2b728
|
@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
|
||||
* Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
|
||||
* Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656))
|
||||
* Rename `reset_on_epoch` to `reset_on_run` ([#9658](https://github.com/PyTorchLightning/pytorch-lightning/pull/9658))
|
||||
|
||||
|
||||
- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
|
||||
|
|
|
@ -56,9 +56,9 @@ class DataLoaderLoop(Loop):
|
|||
def reset(self) -> None:
|
||||
"""Resets the internal state."""
|
||||
if not self.restarting:
|
||||
self.dataloader_progress.current.reset()
|
||||
self.dataloader_progress.reset_on_run()
|
||||
else:
|
||||
self.dataloader_progress.current.reset_on_restart()
|
||||
self.dataloader_progress.reset_on_restart()
|
||||
|
||||
def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
|
||||
self.dataloader_progress.increment_ready()
|
||||
|
|
|
@ -58,9 +58,9 @@ class EvaluationEpochLoop(Loop):
|
|||
self.outputs = []
|
||||
|
||||
if not self.restarting:
|
||||
self.batch_progress.current.reset()
|
||||
self.batch_progress.reset_on_run()
|
||||
else:
|
||||
self.batch_progress.current.reset_on_restart()
|
||||
self.batch_progress.reset_on_restart()
|
||||
|
||||
def on_run_start(
|
||||
self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int
|
||||
|
|
|
@ -46,7 +46,7 @@ class PredictionEpochLoop(Loop):
|
|||
"""Resets the loops internal state."""
|
||||
self._all_batch_indices: List[int] = []
|
||||
self.predictions: List[Any] = []
|
||||
self.batch_progress.current.reset()
|
||||
self.batch_progress.reset_on_run()
|
||||
|
||||
def on_run_start(
|
||||
self,
|
||||
|
|
|
@ -96,12 +96,12 @@ class TrainingEpochLoop(loops.Loop):
|
|||
assert self.batch_loop.optimizer_loop is not None
|
||||
if self.restarting:
|
||||
self.batch_progress.reset_on_restart()
|
||||
self.scheduler_progress.current.reset_on_restart()
|
||||
self.scheduler_progress.reset_on_restart()
|
||||
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
|
||||
else:
|
||||
self.batch_progress.reset_on_epoch()
|
||||
self.scheduler_progress.reset_on_epoch()
|
||||
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
|
||||
self.batch_progress.reset_on_run()
|
||||
self.scheduler_progress.reset_on_run()
|
||||
self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
|
||||
|
||||
# track epoch output
|
||||
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
|
||||
|
|
|
@ -175,7 +175,7 @@ class FitLoop(Loop):
|
|||
def reset(self) -> None:
|
||||
"""Resets the internal state of this loop."""
|
||||
if self.restarting:
|
||||
self.epoch_progress.current.reset_on_restart()
|
||||
self.epoch_progress.reset_on_restart()
|
||||
|
||||
def on_run_start(self) -> None:
|
||||
"""Calls the ``on_train_start`` hook."""
|
||||
|
|
|
@ -151,6 +151,9 @@ class Progress(BaseProgress):
|
|||
def reset_on_epoch(self) -> None:
|
||||
self.current.reset()
|
||||
|
||||
def reset_on_run(self) -> None:
|
||||
self.current.reset()
|
||||
|
||||
def reset_on_restart(self) -> None:
|
||||
self.current.reset_on_restart()
|
||||
|
||||
|
@ -188,8 +191,8 @@ class BatchProgress(Progress):
|
|||
|
||||
is_last_batch: bool = False
|
||||
|
||||
def reset_on_epoch(self) -> None:
|
||||
super().reset_on_epoch()
|
||||
def reset_on_run(self) -> None:
|
||||
super().reset_on_run()
|
||||
self.is_last_batch = False
|
||||
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
|
@ -224,9 +227,9 @@ class OptimizerProgress(BaseProgress):
|
|||
step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
|
||||
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))
|
||||
|
||||
def reset_on_epoch(self) -> None:
|
||||
self.step.reset_on_epoch()
|
||||
self.zero_grad.reset_on_epoch()
|
||||
def reset_on_run(self) -> None:
|
||||
self.step.reset_on_run()
|
||||
self.zero_grad.reset_on_run()
|
||||
|
||||
def reset_on_restart(self) -> None:
|
||||
self.step.reset_on_restart()
|
||||
|
@ -257,8 +260,8 @@ class OptimizationProgress(BaseProgress):
|
|||
def optimizer_steps(self) -> int:
|
||||
return self.optimizer.step.total.completed
|
||||
|
||||
def reset_on_epoch(self) -> None:
|
||||
self.optimizer.reset_on_epoch()
|
||||
def reset_on_run(self) -> None:
|
||||
self.optimizer.reset_on_run()
|
||||
|
||||
def reset_on_restart(self) -> None:
|
||||
self.optimizer.reset_on_restart()
|
||||
|
|
Loading…
Reference in New Issue