diff --git a/CHANGELOG.md b/CHANGELOG.md index e5a4becf40..2701344a4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)) * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371)) * Call `reset_on_restart` in the loop's `reset` hook instead of when loading a checkpoint ([#9561](https://github.com/PyTorchLightning/pytorch-lightning/pull/9561)) + * Use `completed` over `processed` in `reset_on_restart` ([#9656](https://github.com/PyTorchLightning/pytorch-lightning/pull/9656)) - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628)) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 0f07c61999..6c2d95c6b8 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -103,9 +103,8 @@ class ProcessedTracker(StartedTracker): self.processed = 0 def reset_on_restart(self) -> None: - # use `processed` in this case as the reset value - self.completed = self.processed super().reset_on_restart() + self.processed = self.completed @dataclass @@ -149,13 +148,13 @@ class Progress(BaseProgress): """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) + def reset_on_restart(self) -> None: + self.current.reset_on_restart() + def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) self.current.load_state_dict(state_dict["current"]) - def reset_on_restart(self) -> None: - self.current.reset_on_restart() - @dataclass class DataLoaderProgress(Progress): @@ -201,6 +200,10 @@ class OptimizerProgress(BaseProgress): self.step.current.reset() self.zero_grad.current.reset() + def reset_on_restart(self) -> None: + self.step.reset_on_restart() + self.zero_grad.reset_on_restart() + def load_state_dict(self, state_dict: dict) -> None: self.step.load_state_dict(state_dict["step"]) self.zero_grad.load_state_dict(state_dict["zero_grad"]) @@ -229,10 +232,9 @@ class OptimizationProgress(BaseProgress): def reset_on_epoch(self) -> None: self.optimizer.reset_on_epoch() + def reset_on_restart(self) -> None: + self.optimizer.reset_on_restart() + def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) self.optimizer_position = state_dict["optimizer_position"] - - def reset_on_restart(self) -> None: - self.optimizer.step.current.reset_on_restart() - self.optimizer.zero_grad.current.reset_on_restart() diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 47145a2f8f..a9a24d1638 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -700,9 +700,11 @@ def test_fit_loop_reset(tmpdir): assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 2 + assert epoch_loop.batch_progress.total.processed == 2 assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 2 - assert epoch_loop.batch_progress.current.completed == 2 + assert epoch_loop.batch_progress.current.ready == 1 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 1 + assert epoch_loop.batch_progress.current.completed == 1 assert optimizer_loop.restarting assert optimizer_loop.optim_progress.optimizer_position == 1 @@ -730,8 +732,10 @@ def test_fit_loop_reset(tmpdir): assert epoch_loop.restarting assert epoch_loop.batch_progress.total.ready == 4 + assert epoch_loop.batch_progress.total.processed == 4 assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end - assert epoch_loop.batch_progress.current.ready == 0 - assert epoch_loop.batch_progress.current.completed == 0 + assert epoch_loop.batch_progress.current.ready == 3 # currents get set to the completed value + assert epoch_loop.batch_progress.current.processed == 3 + assert epoch_loop.batch_progress.current.completed == 3 assert optimizer_loop.optim_progress.optimizer_position == 1 diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 9cf9064025..2dc8ca2c91 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -38,7 +38,7 @@ def test_tracker_reset_on_restart(): t = ProcessedTracker(ready=4, started=4, processed=3, completed=2) t.reset_on_restart() - assert t == ProcessedTracker(ready=3, started=3, processed=3, completed=3) + assert t == ProcessedTracker(ready=2, started=2, processed=2, completed=2) @pytest.mark.parametrize("attr", ("ready", "started", "processed", "completed"))