diff --git a/CHANGELOG.md b/CHANGELOG.md
index dea6912fb5..d2988105ae 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 - Progress tracking
-  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
-  * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
+  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+  * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
+  * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
 
 
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index e0401a2d7b..8f904938b5 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -98,7 +98,7 @@ class TrainingEpochLoop(loops.Loop):
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
 
-        if not self.restarting:
+        if not self.restarting or self._num_training_batches_reached():
             self.batch_progress.current.reset()
             self.scheduler_progress.current.reset()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 98404f6511..684b3522c1 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -263,7 +263,7 @@ class FitLoop(Loop):
 
     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
-        # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+        # TODO: update has_completed to its proper value
         state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
         return state_dict
 
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index aeecedb915..7d4863f480 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -63,7 +63,7 @@ class OptimizerLoop(Loop):
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
 
     def reset(self) -> None:
-        if not self.restarting:
+        if not self.restarting or self.done:
             self.optim_progress.optimizer_idx = 0
         self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
 
diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index 56c2cae14f..bb3ce98952 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -22,6 +22,7 @@ import pytest
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel
@@ -513,3 +514,235 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
     assert state_dict != checkpoint["loops"]["fit_loop"]
     assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
     assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
+@RunIf(min_torch="1.7.0")
+def test_loop_state_on_complete_run(n_optimizers, tmpdir):
+    n_epochs = 3
+    n_batches = 3
+    accumulate_grad_batches = 1
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            if n_optimizers > 1:
+                self.configure_optimizers = self.configure_optimizers_multiple
+
+        def training_step(self, batch, batch_idx, optimizer_idx=0):
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers_multiple(self):
+            optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]
+
+            lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
+            lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
+            # no scheduler for optimizer_2
+            lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]
+
+            return optimizers, lr_schedulers
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=n_epochs,
+        limit_train_batches=n_batches,
+        limit_val_batches=0,
+        accumulate_grad_batches=accumulate_grad_batches,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        checkpoint_callback=True,
+    )
+    trainer.fit(model)
+
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    assert os.path.exists(ckpt_path)
+    checkpoint = torch.load(ckpt_path)
+
+    n_sch_steps_total = n_epochs
+    n_sch_steps_current = 1
+    if n_optimizers > 1:
+        n_sch_steps_total = n_epochs + n_epochs * n_batches
+        n_sch_steps_current = n_batches + 1
+
+    expected = {
+        "state_dict": ANY,
+        "epoch_progress": {
+            "total": {
+                "ready": n_epochs,
+                "started": n_epochs,
+                "processed": n_epochs,
+                # TODO: the following "-1" offset will be fixed by
+                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
+                "completed": n_epochs - 1,
+            },
+            "current": {
+                "ready": n_epochs,
+                "started": n_epochs,
+                "processed": n_epochs,
+                # TODO: the following "-1" offset will be fixed by
+                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
+                "completed": n_epochs - 1,
+            },
+        },
+        "epoch_loop.state_dict": ANY,
+        "epoch_loop.batch_progress": {
+            "total": {
+                "ready": n_epochs * n_batches,
+                "started": n_epochs * n_batches,
+                "processed": n_epochs * n_batches,
+                "completed": n_epochs * n_batches,
+            },
+            "current": {
+                "ready": n_batches,
+                "started": n_batches,
+                "processed": n_batches,
+                "completed": n_batches,
+            },
+        },
+        "epoch_loop.scheduler_progress": {
+            "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
+            "current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
+        },
+        "epoch_loop.batch_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+            "optimizer_idx": n_optimizers,
+            "optimizer": {
+                "step": {
+                    "total": {
+                        "ready": n_epochs * n_batches * n_optimizers,
+                        "completed": n_epochs * n_batches * n_optimizers,
+                    },
+                    "current": {
+                        "ready": n_batches * n_optimizers,
+                        "completed": n_batches * n_optimizers,
+                    },
+                },
+                "zero_grad": {
+                    "total": {
+                        "ready": n_epochs * n_batches * n_optimizers,
+                        "started": n_epochs * n_batches * n_optimizers,
+                        "completed": n_epochs * n_batches * n_optimizers,
+                    },
+                    "current": {
+                        "ready": n_batches * n_optimizers,
+                        "started": n_batches * n_optimizers,
+                        "completed": n_batches * n_optimizers,
+                    },
+                },
+            },
+        },
+        "epoch_loop.val_loop.state_dict": ANY,
+        "epoch_loop.val_loop.dataloader_progress": ANY,
+        "epoch_loop.val_loop.epoch_loop.state_dict": ANY,
+        "epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
+        "epoch_loop.val_loop._results": ANY,
+        "epoch_loop._results": ANY,
+    }
+    assert checkpoint["loops"]["fit_loop"] == expected
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+def test_fit_loop_reset(tmpdir):
+    """Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
+    loop or from a mid-epoch checkpoint."""
+
+    # generate checkpoints at end of epoch and mid-epoch
+    model = BoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir,
+        every_n_train_steps=2,
+        save_top_k=-1,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        num_sanity_val_steps=0,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # reset state loaded from a checkpoint from mid-epoch
+    mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
+    fit_loop = trainer.fit_loop
+    epoch_loop = fit_loop.epoch_loop
+    optimizer_loop = epoch_loop.batch_loop.optimizer_loop
+    assert not fit_loop.restarting
+    assert not epoch_loop.restarting
+    assert not optimizer_loop.restarting
+
+    fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
+
+    def mid_epoch_reset_assertions():
+        assert fit_loop.restarting
+        assert fit_loop.epoch_progress.total.ready == 1
+        assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+        assert fit_loop.epoch_progress.current.ready == 0
+        assert fit_loop.epoch_progress.current.completed == 0
+
+        assert epoch_loop.restarting
+        assert epoch_loop.batch_progress.total.ready == 2
+        assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+        assert epoch_loop.batch_progress.current.ready == 2
+        assert epoch_loop.batch_progress.current.completed == 2
+
+    # resetting from a mid-epoch checkpoint should not change progress counters
+    mid_epoch_reset_assertions()
+    assert optimizer_loop.optim_progress.optimizer_idx == 1
+    fit_loop.reset()
+    epoch_loop.reset()
+    optimizer_loop.reset()
+    mid_epoch_reset_assertions()
+    assert optimizer_loop.optim_progress.optimizer_idx == 0
+
+    # reset state loaded from a checkpoint from the end of an epoch
+    end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
+    fit_loop = trainer.fit_loop
+    epoch_loop = fit_loop.epoch_loop
+    fit_loop.restarting = False
+    epoch_loop.restarting = False
+    optimizer_loop.restarting = False
+
+    fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 4
+    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 4
+    assert epoch_loop.batch_progress.current.completed == 4
+
+    assert optimizer_loop.optim_progress.optimizer_idx == 1
+
+    # resetting from an end-of-epoch checkpoint should reset the current counters to 0
+    fit_loop.reset()
+    epoch_loop.reset()
+    optimizer_loop.reset()
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 4
+    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 0
+    assert epoch_loop.batch_progress.current.completed == 0
+
+    assert optimizer_loop.optim_progress.optimizer_idx == 0
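Illustration (not part of the patch): a minimal, self-contained sketch of the reset pattern that the `reset()` changes above introduce. `ToyEpochLoop`, `Progress`, and `Tracker` below are simplified stand-ins, not Lightning's actual classes; only the `if not self.restarting or self.done:` guard mirrors the patched logic, namely that `current` counters are kept when resuming mid-epoch but zeroed when the restored loop had already finished.

# Hypothetical toy classes for illustration only; not Lightning's implementation.
from dataclasses import dataclass, field


@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)
    current: Tracker = field(default_factory=Tracker)


class ToyEpochLoop:
    """Simplified stand-in for a loop that tracks `total` and `current` progress."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.restarting = False
        self.batch_progress = Progress()

    @property
    def done(self) -> bool:
        return self.batch_progress.current.completed >= self.num_batches

    def reset(self) -> None:
        # The pattern from the patch: keep `current` counters when resuming
        # mid-epoch, but reset them when the restored loop had already finished.
        if not self.restarting or self.done:
            self.batch_progress.current.reset()

    def run_epoch(self) -> None:
        self.reset()
        while not self.done:
            self.batch_progress.total.ready += 1
            self.batch_progress.current.ready += 1
            self.batch_progress.total.completed += 1
            self.batch_progress.current.completed += 1


loop = ToyEpochLoop(num_batches=3)
loop.run_epoch()
assert loop.batch_progress.current.completed == 3

# Simulate restoring from an end-of-epoch checkpoint. Without the `or self.done`
# branch, `current` would stay at 3, `done` would already be True, and the new
# epoch would be skipped entirely.
loop.restarting = True
loop.run_epoch()
assert loop.batch_progress.total.completed == 6
assert loop.batch_progress.current.completed == 3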