fix resuming from checkpoint for fault-tolerant in case of no failure (#9371)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
parent 7ca038b83f
commit 6ff43cbff7
@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Progress tracking
-  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
-  * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
+  * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+  * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
+  * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))


 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
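The changelog entry for #9371 above sums up the behavioral change: when a loop restarts from a checkpoint written after it had already finished, its `current` progress counters are reset so that resuming without a failure starts the next run cleanly, while a mid-run restart keeps them intact. A minimal, self-contained sketch of that guard, using simplified stand-ins rather than the actual Lightning `Progress`/`Loop` classes:

from dataclasses import dataclass, field


@dataclass
class Progress:
    """Simplified stand-in for Lightning's progress tracker (not the real class)."""

    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class MiniLoop:
    """Sketch of the reset guard introduced by this commit."""

    max_steps: int
    restarting: bool = False
    current: Progress = field(default_factory=Progress)

    @property
    def done(self) -> bool:
        # The loop counts as finished once it has consumed all its steps.
        return self.current.ready >= self.max_steps

    def reset(self) -> None:
        # Fresh start, or restarting from a checkpoint saved after the loop
        # finished: clear the per-run ("current") counters.
        if not self.restarting or self.done:
            self.current.reset()
        # A mid-run restart keeps the counters so no work is repeated.

The real loops apply the same condition, spelled as `self._num_training_batches_reached()` in `TrainingEpochLoop.reset` and `self.done` in `OptimizerLoop.reset`, as shown in the hunks below.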
@@ -98,7 +98,7 @@ class TrainingEpochLoop(loops.Loop):
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]

-        if not self.restarting:
+        if not self.restarting or self._num_training_batches_reached():
             self.batch_progress.current.reset()
             self.scheduler_progress.current.reset()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
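For context, the new guard calls `self._num_training_batches_reached()`, which reports whether the epoch loop has already consumed every training batch of the epoch. A simplified, hypothetical free-function version of that check is sketched below; the real method reads the counts from the loop's own state:

def _num_training_batches_reached(batches_processed: int, num_training_batches: int) -> bool:
    # Hypothetical stand-in for the method used in the guard above: the epoch
    # counts as finished once every training batch has been processed.
    return batches_processed >= num_training_batches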
@@ -263,7 +263,7 @@ class FitLoop(Loop):

     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
-        # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+        # TODO: update has_completed to its proper value
         state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
         return state_dict
@@ -63,7 +63,7 @@ class OptimizerLoop(Loop):
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")

     def reset(self) -> None:
-        if not self.restarting:
+        if not self.restarting or self.done:
             self.optim_progress.optimizer_idx = 0
             self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
@@ -22,6 +22,7 @@ import pytest
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel
@ -513,3 +514,235 @@ def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch
|
|||
assert state_dict != checkpoint["loops"]["fit_loop"]
|
||||
assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
|
||||
assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
|
||||
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
|
||||
@RunIf(min_torch="1.7.0")
|
||||
def test_loop_state_on_complete_run(n_optimizers, tmpdir):
|
||||
n_epochs = 3
|
||||
n_batches = 3
|
||||
accumulate_grad_batches = 1
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if n_optimizers > 1:
|
||||
self.configure_optimizers = self.configure_optimizers_multiple
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=0):
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
def configure_optimizers_multiple(self):
|
||||
optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]
|
||||
|
||||
lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
|
||||
lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
|
||||
# no scheduler for optimizer_2
|
||||
lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]
|
||||
|
||||
return optimizers, lr_schedulers
|
||||
|
||||
model = TestModel()
|
||||
model.training_epoch_end = None
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=n_epochs,
|
||||
limit_train_batches=n_batches,
|
||||
limit_val_batches=0,
|
||||
accumulate_grad_batches=accumulate_grad_batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
logger=False,
|
||||
checkpoint_callback=True,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
ckpt_path = trainer.checkpoint_callback.best_model_path
|
||||
assert os.path.exists(ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path)
|
||||
|
||||
n_sch_steps_total = n_epochs
|
||||
n_sch_steps_current = 1
|
||||
if n_optimizers > 1:
|
||||
n_sch_steps_total = n_epochs + n_epochs * n_batches
|
||||
n_sch_steps_current = n_batches + 1
|
||||
|
||||
expected = {
|
||||
"state_dict": ANY,
|
||||
"epoch_progress": {
|
||||
"total": {
|
||||
"ready": n_epochs,
|
||||
"started": n_epochs,
|
||||
"processed": n_epochs,
|
||||
# TODO: the following "-1" offset will be fixed by
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
|
||||
"completed": n_epochs - 1,
|
||||
},
|
||||
"current": {
|
||||
"ready": n_epochs,
|
||||
"started": n_epochs,
|
||||
"processed": n_epochs,
|
||||
# TODO: the following "-1" offset will be fixed by
|
||||
# https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
|
||||
"completed": n_epochs - 1,
|
||||
},
|
||||
},
|
||||
"epoch_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_progress": {
|
||||
"total": {
|
||||
"ready": n_epochs * n_batches,
|
||||
"started": n_epochs * n_batches,
|
||||
"processed": n_epochs * n_batches,
|
||||
"completed": n_epochs * n_batches,
|
||||
},
|
||||
"current": {
|
||||
"ready": n_batches,
|
||||
"started": n_batches,
|
||||
"processed": n_batches,
|
||||
"completed": n_batches,
|
||||
},
|
||||
},
|
||||
"epoch_loop.scheduler_progress": {
|
||||
"total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
|
||||
"current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
|
||||
},
|
||||
"epoch_loop.batch_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
|
||||
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
|
||||
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
|
||||
"optimizer_idx": n_optimizers,
|
||||
"optimizer": {
|
||||
"step": {
|
||||
"total": {
|
||||
"ready": n_epochs * n_batches * n_optimizers,
|
||||
"completed": n_epochs * n_batches * n_optimizers,
|
||||
},
|
||||
"current": {
|
||||
"ready": n_batches * n_optimizers,
|
||||
"completed": n_batches * n_optimizers,
|
||||
},
|
||||
},
|
||||
"zero_grad": {
|
||||
"total": {
|
||||
"ready": n_epochs * n_batches * n_optimizers,
|
||||
"started": n_epochs * n_batches * n_optimizers,
|
||||
"completed": n_epochs * n_batches * n_optimizers,
|
||||
},
|
||||
"current": {
|
||||
"ready": n_batches * n_optimizers,
|
||||
"started": n_batches * n_optimizers,
|
||||
"completed": n_batches * n_optimizers,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"epoch_loop.val_loop.state_dict": ANY,
|
||||
"epoch_loop.val_loop.dataloader_progress": ANY,
|
||||
"epoch_loop.val_loop.epoch_loop.state_dict": ANY,
|
||||
"epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
|
||||
"epoch_loop.val_loop._results": ANY,
|
||||
"epoch_loop._results": ANY,
|
||||
}
|
||||
assert checkpoint["loops"]["fit_loop"] == expected
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
|
||||
@RunIf(min_torch="1.7.0")
|
||||
def test_fit_loop_reset(tmpdir):
|
||||
"""Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
|
||||
loop or from a mid-epoch checkpoint."""
|
||||
|
||||
# generate checkpoints at end of epoch and mid-epoch
|
||||
model = BoringModel()
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
dirpath=tmpdir,
|
||||
every_n_train_steps=2,
|
||||
save_top_k=-1,
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=4,
|
||||
num_sanity_val_steps=0,
|
||||
max_epochs=2,
|
||||
callbacks=[checkpoint_callback],
|
||||
logger=False,
|
||||
weights_summary=None,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# reset state loaded from a checkpoint from mid-epoch
|
||||
mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
|
||||
fit_loop = trainer.fit_loop
|
||||
epoch_loop = fit_loop.epoch_loop
|
||||
optimizer_loop = epoch_loop.batch_loop.optimizer_loop
|
||||
assert not fit_loop.restarting
|
||||
assert not epoch_loop.restarting
|
||||
assert not optimizer_loop.restarting
|
||||
|
||||
fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
|
||||
|
||||
def mid_epoch_reset_assertions():
|
||||
assert fit_loop.restarting
|
||||
assert fit_loop.epoch_progress.total.ready == 1
|
||||
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint was saved mid epoch
|
||||
assert fit_loop.epoch_progress.current.ready == 0
|
||||
assert fit_loop.epoch_progress.current.completed == 0
|
||||
|
||||
assert epoch_loop.restarting
|
||||
assert epoch_loop.batch_progress.total.ready == 2
|
||||
assert epoch_loop.batch_progress.total.completed == 1 # the checkpoint was saved on train_batch_end
|
||||
assert epoch_loop.batch_progress.current.ready == 2
|
||||
assert epoch_loop.batch_progress.current.completed == 2
|
||||
|
||||
# resetting from a mid-epoch checkpoint should not change progress counters
|
||||
mid_epoch_reset_assertions()
|
||||
assert optimizer_loop.optim_progress.optimizer_idx == 1
|
||||
fit_loop.reset()
|
||||
epoch_loop.reset()
|
||||
optimizer_loop.reset()
|
||||
mid_epoch_reset_assertions()
|
||||
assert optimizer_loop.optim_progress.optimizer_idx == 0
|
||||
|
||||
# reset state loaded from a checkpoint from the end of an epoch
|
||||
end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
|
||||
fit_loop = trainer.fit_loop
|
||||
epoch_loop = fit_loop.epoch_loop
|
||||
fit_loop.restarting = False
|
||||
epoch_loop.restarting = False
|
||||
optimizer_loop.restarting = False
|
||||
|
||||
fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
|
||||
|
||||
assert fit_loop.restarting
|
||||
assert fit_loop.epoch_progress.total.ready == 1
|
||||
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
|
||||
assert fit_loop.epoch_progress.current.ready == 0
|
||||
assert fit_loop.epoch_progress.current.completed == 0
|
||||
|
||||
assert epoch_loop.restarting
|
||||
assert epoch_loop.batch_progress.total.ready == 4
|
||||
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
|
||||
assert epoch_loop.batch_progress.current.ready == 4
|
||||
assert epoch_loop.batch_progress.current.completed == 4
|
||||
|
||||
assert optimizer_loop.optim_progress.optimizer_idx == 1
|
||||
|
||||
# resetting from a end-of-epoch checkpoint should reset the current counters to 0
|
||||
fit_loop.reset()
|
||||
epoch_loop.reset()
|
||||
optimizer_loop.reset()
|
||||
|
||||
assert fit_loop.restarting
|
||||
assert fit_loop.epoch_progress.total.ready == 1
|
||||
assert fit_loop.epoch_progress.total.completed == 0 # the checkpoint saves before the epoch completes
|
||||
assert fit_loop.epoch_progress.current.ready == 0
|
||||
assert fit_loop.epoch_progress.current.completed == 0
|
||||
|
||||
assert epoch_loop.restarting
|
||||
assert epoch_loop.batch_progress.total.ready == 4
|
||||
assert epoch_loop.batch_progress.total.completed == 3 # the checkpoint was saved on train_batch_end
|
||||
assert epoch_loop.batch_progress.current.ready == 0
|
||||
assert epoch_loop.batch_progress.current.completed == 0
|
||||
|
||||
assert optimizer_loop.optim_progress.optimizer_idx == 0
|
||||
|
|
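The tests above exercise the loop state directly; end to end, the same code path runs when a user resumes a finished (or interrupted) fault-tolerant run. A hypothetical resume script is sketched below: the checkpoint path is illustrative, `BoringModel` comes from the test helpers used above, and `resume_from_checkpoint` is the Trainer resume argument available at the time of this commit.

import os

from pytorch_lightning import Trainer
from tests.helpers import BoringModel

# Enable fault-tolerant training so loop and dataloader state is saved and restored.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

model = BoringModel()
trainer = Trainer(
    max_epochs=4,  # continue past the epochs of the original run
    limit_val_batches=0,
    # Illustrative path: point this at a checkpoint produced by the first run.
    resume_from_checkpoint="lightning_logs/version_0/checkpoints/epoch=2-step=8.ckpt",
)
# On restore, the loops decide whether to keep or reset their `current`
# counters based on whether the checkpointed loop had already finished.
trainer.fit(model)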