fix resuming from checkpoint for fault-tolerant training in case of no failure (#9371)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Adrian Wälchli 2021-09-10 19:25:46 +02:00 committed by GitHub
parent 7ca038b83f
commit 6ff43cbff7
5 changed files with 239 additions and 5 deletions

@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
* Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
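
The changelog entry added by this commit refers to Lightning's progress tracking, where every counter is split into a `total` part that accumulates over the whole run and a `current` part that only covers the span in flight (for the epoch loop: the current epoch). The toy sketch below is not part of this commit and uses simplified stand-ins for the real progress dataclasses; it only illustrates why the `current` part has to be cleared when a loop that had already finished is restarted from a checkpoint:

from dataclasses import dataclass, field

@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0

@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)    # monotonic over the whole run
    current: Tracker = field(default_factory=Tracker)  # per-epoch, restarts at 0

def reset_on_restart(progress: Progress, restarting: bool, loop_finished: bool) -> None:
    # the gist of the fix: a checkpoint taken after the loop finished carries
    # stale per-epoch counters, so they must be cleared before resuming
    if not restarting or loop_finished:
        progress.current.reset()

# mid-epoch checkpoint: keep the position so training resumes where it stopped
p = Progress(total=Tracker(2, 2), current=Tracker(2, 2))
reset_on_restart(p, restarting=True, loop_finished=False)
assert p.current.completed == 2

# end-of-epoch checkpoint: the next epoch must start counting from 0 again
p = Progress(total=Tracker(3, 3), current=Tracker(3, 3))
reset_on_restart(p, restarting=True, loop_finished=True)
assert p.current.completed == 0 and p.total.completed == 3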

@@ -98,7 +98,7 @@ class TrainingEpochLoop(loops.Loop):
# track epoch output
self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
if not self.restarting:
if not self.restarting or self._num_training_batches_reached():
self.batch_progress.current.reset()
self.scheduler_progress.current.reset()
self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
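
Before this change, a restarting epoch loop never reset its `current` counters, which is correct for a mid-epoch checkpoint but wrong when the checkpoint was written after the last batch of an epoch: the resumed loop would start the next epoch with counters claiming it is already complete. The added guard covers that case. As a rough sketch (an assumption about the helper, not a verbatim copy of the Lightning source), `_num_training_batches_reached` compares the per-epoch batch counter against the epoch length:

def _num_training_batches_reached(self) -> bool:
    # assumed simplification; the real helper may also honour an
    # "is last batch" flag for iterable-style datasets
    return self.batch_progress.current.ready == self.trainer.num_training_batches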

@@ -263,7 +263,7 @@ class FitLoop(Loop):
def on_save_checkpoint(self) -> Dict:
state_dict = super().on_save_checkpoint()
# FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
# TODO: update has_completed to its proper value
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
return state_dict
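
For context on the hook being edited here: `on_save_checkpoint` and `on_load_checkpoint` are the `Loop` hooks for persisting extra state (such as the dataloader state above) alongside the progress counters. The class below is a made-up minimal loop, not part of this commit, showing the round trip; only the hook and attribute names (`done`, `reset`, `advance`, `connect`, `restarting`) come from the `Loop` API:

from pytorch_lightning.loops import Loop

class CountingLoop(Loop):
    """Hypothetical loop that persists a single counter in the checkpoint."""

    def __init__(self, limit: int = 3) -> None:
        super().__init__()
        self.limit = limit
        self.count = 0

    @property
    def done(self) -> bool:
        return self.count >= self.limit

    def reset(self) -> None:
        if not self.restarting:
            self.count = 0

    def advance(self) -> None:
        self.count += 1

    def connect(self, **kwargs) -> None:
        pass  # this toy loop has no child loops to connect

    def on_save_checkpoint(self) -> dict:
        # ends up under checkpoint["loops"][...]["state_dict"], next to the
        # dataloader state dict saved by FitLoop above
        return {"count": self.count}

    def on_load_checkpoint(self, state_dict: dict) -> None:
        self.count = state_dict["count"]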

@@ -63,7 +63,7 @@ class OptimizerLoop(Loop):
raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
def reset(self) -> None:
if not self.restarting:
if not self.restarting or self.done:
self.optim_progress.optimizer_idx = 0
self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
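
The same situation exists one level further down: when the optimizer loop has finished, `optimizer_idx` has advanced past the last optimizer (the checkpoints in the tests below are written on `train_batch_end`, so a single-optimizer run stores `optimizer_idx == 1`). Restoring that value verbatim on restart would make the next batch skip every optimizer, hence the extra `self.done` condition. A toy illustration, with the meaning of "done" assumed rather than copied from the library:

def reset_optimizer_idx(optimizer_idx: int, n_optimizers: int, restarting: bool) -> int:
    # assumption: "done" means every configured optimizer already ran for this batch
    done = optimizer_idx >= n_optimizers
    if not restarting or done:
        return 0
    return optimizer_idx

assert reset_optimizer_idx(1, 1, restarting=True) == 0   # finished loop: rewind to 0
assert reset_optimizer_idx(1, 3, restarting=True) == 1   # interrupted mid-way: keep the position
assert reset_optimizer_idx(2, 3, restarting=False) == 0  # fresh run: always start at 0

This mirrors `test_fit_loop_reset` below, where the loaded checkpoints hold `optimizer_idx == 1` and `optimizer_loop.reset()` rewinds it to `0`.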

@@ -22,6 +22,7 @@ import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loops import Loop, TrainingBatchLoop
from pytorch_lightning.trainer.progress import BaseProgress
from tests.helpers import BoringModel
@@ -513,3 +514,235 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
assert state_dict != checkpoint["loops"]["fit_loop"]
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@RunIf(min_torch="1.7.0")
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
n_epochs = 3
n_batches = 3
accumulate_grad_batches = 1
class TestModel(BoringModel):
def __init__(self):
super().__init__()
if n_optimizers > 1:
self.configure_optimizers = self.configure_optimizers_multiple
def training_step(self, batch, batch_idx, optimizer_idx=0):
return super().training_step(batch, batch_idx)
def configure_optimizers_multiple(self):
optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]
lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
# no scheduler for optimizer_2
lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]
return optimizers, lr_schedulers
model = TestModel()
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=n_epochs,
limit_train_batches=n_batches,
limit_val_batches=0,
accumulate_grad_batches=accumulate_grad_batches,
progress_bar_refresh_rate=0,
logger=False,
checkpoint_callback=True,
)
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path
assert os.path.exists(ckpt_path)
checkpoint = torch.load(ckpt_path)
n_sch_steps_total = n_epochs
n_sch_steps_current = 1
if n_optimizers > 1:
n_sch_steps_total = n_epochs + n_epochs * n_batches
n_sch_steps_current = n_batches + 1
expected = {
"state_dict": ANY,
"epoch_progress": {
"total": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
"current": {
"ready": n_epochs,
"started": n_epochs,
"processed": n_epochs,
# TODO: the following "-1" offset will be fixed by
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
"completed": n_epochs - 1,
},
},
"epoch_loop.state_dict": ANY,
"epoch_loop.batch_progress": {
"total": {
"ready": n_epochs * n_batches,
"started": n_epochs * n_batches,
"processed": n_epochs * n_batches,
"completed": n_epochs * n_batches,
},
"current": {
"ready": n_batches,
"started": n_batches,
"processed": n_batches,
"completed": n_batches,
},
},
"epoch_loop.scheduler_progress": {
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_idx": n_optimizers,
"optimizer": {
"step": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
"zero_grad": {
"total": {
"ready": n_epochs * n_batches * n_optimizers,
"started": n_epochs * n_batches * n_optimizers,
"completed": n_epochs * n_batches * n_optimizers,
},
"current": {
"ready": n_batches * n_optimizers,
"started": n_batches * n_optimizers,
"completed": n_batches * n_optimizers,
},
},
},
},
"epoch_loop.val_loop.state_dict": ANY,
"epoch_loop.val_loop.dataloader_progress": ANY,
"epoch_loop.val_loop.epoch_loop.state_dict": ANY,
"epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
"epoch_loop.val_loop._results": ANY,
"epoch_loop._results": ANY,
}
assert checkpoint["loops"]["fit_loop"] == expected
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
def test_fit_loop_reset(tmpdir):
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
loop or from a mid-epoch checkpoint."""
# generate checkpoints at end of epoch and mid-epoch
model = BoringModel()
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
every_n_train_steps=2,
save_top_k=-1,
)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=4,
num_sanity_val_steps=0,
max_epochs=2,
callbacks=[checkpoint_callback],
logger=False,
weights_summary=None,
)
trainer.fit(model)
# reset state loaded from a checkpoint from mid-epoch
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
optimizer_loop = epoch_loop.batch_loop.optimizer_loop
assert not fit_loop.restarting
assert not epoch_loop.restarting
assert not optimizer_loop.restarting
fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
def mid_epoch_reset_assertions():
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 2
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 2
assert epoch_loop.batch_progress.current.completed == 2
# resetting from a mid-epoch checkpoint should not change progress counters
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_idx == 1
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()
mid_epoch_reset_assertions()
assert optimizer_loop.optim_progress.optimizer_idx == 0
# reset state loaded from a checkpoint from the end of an epoch
end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
fit_loop = trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
fit_loop.restarting = False
epoch_loop.restarting = False
optimizer_loop.restarting = False
fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 4
assert epoch_loop.batch_progress.current.completed == 4
assert optimizer_loop.optim_progress.optimizer_idx == 1
# resetting from an end-of-epoch checkpoint should reset the current counters to 0
fit_loop.reset()
epoch_loop.reset()
optimizer_loop.reset()
assert fit_loop.restarting
assert fit_loop.epoch_progress.total.ready == 1
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
assert fit_loop.epoch_progress.current.ready == 0
assert fit_loop.epoch_progress.current.completed == 0
assert epoch_loop.restarting
assert epoch_loop.batch_progress.total.ready == 4
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
assert epoch_loop.batch_progress.current.ready == 0
assert epoch_loop.batch_progress.current.completed == 0
assert optimizer_loop.optim_progress.optimizer_idx == 0
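
For completeness, this is roughly how the fixed path is hit outside the test suite: train with fault-tolerant training enabled, then resume from a checkpoint written at an epoch boundary. A sketch only; the checkpoint path is hypothetical and `resume_from_checkpoint` is the resume mechanism available on `Trainer` at the time of this commit:

import os
from pytorch_lightning import Trainer
from tests.helpers import BoringModel

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # opt into fault-tolerant state capture

# initial run: the saved checkpoints now carry the loop state verified above
Trainer(default_root_dir="demo", max_epochs=2, limit_train_batches=4).fit(BoringModel())

# resume from an end-of-epoch checkpoint: with this fix, the restored epoch and
# optimizer loops clear their `current` counters instead of starting the next
# epoch as if it were already finished
trainer = Trainer(
    default_root_dir="demo",
    max_epochs=4,
    limit_train_batches=4,
    resume_from_checkpoint="demo/lightning_logs/version_0/checkpoints/epoch=1-step=7.ckpt",  # hypothetical path
)
trainer.fit(BoringModel())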