Rename `reset_on_epoch` to `reset_on_run` (#9658)

Carlos Mocholí 2021-09-25 04:27:54 +02:00 committed by GitHub
parent 8fcdcb598b
commit d02fc2b728
7 changed files with 21 additions and 17 deletions


@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
     * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561))
     * Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656))
+    * Rename `reset_on_epoch` to `reset_on_run` ([#9658](https://github.com/PyTorchLightning/pytorch-lightning/pull/9658))
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))


@@ -56,9 +56,9 @@ class DataLoaderLoop(Loop):
     def reset(self) -> None:
         """Resets the internal state."""
         if not self.restarting:
-            self.dataloader_progress.current.reset()
+            self.dataloader_progress.reset_on_run()
         else:
-            self.dataloader_progress.current.reset_on_restart()
+            self.dataloader_progress.reset_on_restart()

     def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
         self.dataloader_progress.increment_ready()


@@ -58,9 +58,9 @@ class EvaluationEpochLoop(Loop):
         self.outputs = []

         if not self.restarting:
-            self.batch_progress.current.reset()
+            self.batch_progress.reset_on_run()
         else:
-            self.batch_progress.current.reset_on_restart()
+            self.batch_progress.reset_on_restart()

     def on_run_start(
         self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int


@@ -46,7 +46,7 @@ class PredictionEpochLoop(Loop):
         """Resets the loops internal state."""
         self._all_batch_indices: List[int] = []
         self.predictions: List[Any] = []
-        self.batch_progress.current.reset()
+        self.batch_progress.reset_on_run()

     def on_run_start(
         self,


@@ -96,12 +96,12 @@ class TrainingEpochLoop(loops.Loop):
         assert self.batch_loop.optimizer_loop is not None
         if self.restarting:
             self.batch_progress.reset_on_restart()
-            self.scheduler_progress.current.reset_on_restart()
+            self.scheduler_progress.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
         else:
-            self.batch_progress.reset_on_epoch()
-            self.scheduler_progress.reset_on_epoch()
-            self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
+            self.batch_progress.reset_on_run()
+            self.scheduler_progress.reset_on_run()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_run()

         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]


@@ -175,7 +175,7 @@ class FitLoop(Loop):
     def reset(self) -> None:
         """Resets the internal state of this loop."""
         if self.restarting:
-            self.epoch_progress.current.reset_on_restart()
+            self.epoch_progress.reset_on_restart()

     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""


@@ -151,6 +151,9 @@ class Progress(BaseProgress):
-    def reset_on_epoch(self) -> None:
-        self.current.reset()
+    def reset_on_run(self) -> None:
+        self.current.reset()

     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()
@@ -188,8 +191,8 @@ class BatchProgress(Progress):
     is_last_batch: bool = False

-    def reset_on_epoch(self) -> None:
-        super().reset_on_epoch()
+    def reset_on_run(self) -> None:
+        super().reset_on_run()
         self.is_last_batch = False

     def load_state_dict(self, state_dict: dict) -> None:
@@ -224,9 +227,9 @@ class OptimizerProgress(BaseProgress):
     step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
     zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

-    def reset_on_epoch(self) -> None:
-        self.step.reset_on_epoch()
-        self.zero_grad.reset_on_epoch()
+    def reset_on_run(self) -> None:
+        self.step.reset_on_run()
+        self.zero_grad.reset_on_run()

     def reset_on_restart(self) -> None:
         self.step.reset_on_restart()
@@ -257,8 +260,8 @@ class OptimizationProgress(BaseProgress):
     def optimizer_steps(self) -> int:
         return self.optimizer.step.total.completed

-    def reset_on_epoch(self) -> None:
-        self.optimizer.reset_on_epoch()
+    def reset_on_run(self) -> None:
+        self.optimizer.reset_on_run()

     def reset_on_restart(self) -> None:
         self.optimizer.reset_on_restart()
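
The net effect of the rename, shown as a minimal, self-contained sketch: loops call `reset_on_run()` on their progress dataclasses when starting a fresh run and `reset_on_restart()` when resuming from a checkpoint. The `_Tracker`, `_Progress`, and `_SomeLoop` classes below are simplified, hypothetical stand-ins used only to illustrate that calling convention, not the actual `pytorch_lightning.trainer.progress` implementations.

# Illustrative sketch only -- simplified stand-ins, not the real
# pytorch_lightning.trainer.progress classes.
from dataclasses import dataclass, field


@dataclass
class _Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0

    def reset_on_restart(self) -> None:
        # When restarting, fall back to the last fully completed count.
        self.ready = self.completed


@dataclass
class _Progress:
    total: _Tracker = field(default_factory=_Tracker)
    current: _Tracker = field(default_factory=_Tracker)

    def reset_on_run(self) -> None:
        # Fresh run: clear the `current` counters, keep `total` accumulating.
        self.current.reset()

    def reset_on_restart(self) -> None:
        # Resumed run: reconcile `current` with what actually completed.
        self.current.reset_on_restart()


class _SomeLoop:
    """Mirrors the branching used by the loops touched in this commit."""

    def __init__(self) -> None:
        self.restarting = False
        self.batch_progress = _Progress()

    def reset(self) -> None:
        if not self.restarting:
            self.batch_progress.reset_on_run()
        else:
            self.batch_progress.reset_on_restart()

Keeping the `reset_on_run` / `reset_on_restart` pair symmetric across `Progress`, `BatchProgress`, `OptimizerProgress`, and `OptimizationProgress` is what the hunks above converge on.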