Fix resuming from a checkpoint rerunning the last epoch (#866)

* Properly restore current epoch and global step on resume

* Add test

* Move increment to saving rather than loading

* Fix other tests that refer to current epoch

* Formatting

* Add warning for mid-epoch resuming

* Formatting

* Fix warning check for accumulated batches

* Add variable to init

* Formatting

* Add check for 0 training steps

* Make check more readable
Matt Painter 2020-02-22 01:27:19 +00:00 committed by GitHub
parent 2b5458e852
commit 6e7dc9c236
3 changed files with 87 additions and 6 deletions


@@ -125,6 +125,8 @@ class TrainerIOMixin(ABC):
         self.early_stop_callback = None
         self.lr_schedulers = None
         self.optimizers = None
+        self.num_training_batches = None
+        self.accumulate_grad_batches = None
 
     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel,
@@ -305,10 +307,9 @@ class TrainerIOMixin(ABC):
         self.restore_training_state(checkpoint)
 
     def dump_checkpoint(self):
         checkpoint = {
-            'epoch': self.current_epoch,
-            'global_step': self.global_step
+            'epoch': self.current_epoch + 1,
+            'global_step': self.global_step + 1,
         }
 
         if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
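
For context (not part of the diff): the values are now written already incremented. Under the old code the saved epoch was the one that had just finished, so a resumed run repeated it; saving epoch + 1 means the restored trainer starts at the next epoch. A standalone sketch of the round trip, with hypothetical helper names:

# Hypothetical helpers illustrating the save/restore round trip, not Lightning code.
def dump(current_epoch, global_step):
    # Mirror the new dump_checkpoint: store the *next* epoch/step to run.
    return {'epoch': current_epoch + 1, 'global_step': global_step + 1}

def restore(checkpoint):
    # Mirror restore_training_state: take the stored values as-is.
    return checkpoint['epoch'], checkpoint['global_step']

# Suppose epoch 4 (0-indexed) just finished after 500 batches:
ckpt = dump(current_epoch=4, global_step=499)
epoch, step = restore(ckpt)
assert epoch == 5 and step == 500  # resume begins at epoch 5 instead of rerunning epoch 4
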
@@ -388,6 +389,17 @@ class TrainerIOMixin(ABC):
         self.global_step = checkpoint['global_step']
         self.current_epoch = checkpoint['epoch']
 
+        # Division deals with global step stepping once per accumulated batch
+        # Inequality deals with different global step for odd vs even num_training_batches
+        n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
+        expected_steps = self.num_training_batches / n_accum
+        if self.num_training_batches != 0 and self.global_step % expected_steps > 1:
+            warnings.warn(
+                "You're resuming from a checkpoint that ended mid-epoch. "
+                "This can cause unreliable results if further training is done, "
+                "consider using an end of epoch checkpoint. "
+            )
+
         # restore the optimizers
         optimizer_states = checkpoint['optimizer_states']
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
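
A worked example of the new warning check, with made-up numbers, showing how gradient accumulation and the > 1 tolerance play out; this is a standalone sketch, not library code.

import warnings

def looks_mid_epoch(global_step, num_training_batches, accumulate_grad_batches):
    # Same arithmetic as the check added to restore_training_state above.
    n_accum = 1 if accumulate_grad_batches is None else accumulate_grad_batches
    expected_steps = num_training_batches / n_accum  # optimizer steps per epoch
    if num_training_batches != 0 and global_step % expected_steps > 1:
        warnings.warn("Checkpoint appears to have been saved mid-epoch.")
        return True
    return False

# 100 batches per epoch with accumulation of 2 -> 50 optimizer steps per epoch.
assert looks_mid_epoch(100, 100, 2) is False  # 100 % 50 == 0: consistent with an end-of-epoch checkpoint
assert looks_mid_epoch(101, 100, 2) is False  # 101 % 50 == 1: an off-by-one is tolerated by the > 1 comparison
assert looks_mid_epoch(75, 100, 2) is True    # 75 % 50 == 25: halfway through an epoch triggers the warning
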


@@ -214,8 +214,8 @@ def test_dp_resume(tmpdir):
     trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)
 
-    # track epoch before saving
-    real_global_epoch = trainer.current_epoch
+    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1
 
     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
@@ -282,7 +282,8 @@ def test_cpu_restore_training(tmpdir):
     # fit model
     trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
-    real_global_epoch = trainer.current_epoch
+    # Increment since we've finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1
 
     # training complete
     assert result == 1, 'amp + ddp model failed to complete'


@@ -413,6 +413,74 @@ def test_multiple_val_dataloader(tmpdir):
         tutils.run_prediction(dataloader, trainer.model)
 
 
+def test_resume_from_checkpoint_epoch_restored(tmpdir):
+    """Verify resuming from checkpoint runs the right number of epochs"""
+    import types
+
+    tutils.reset_seed()
+
+    hparams = tutils.get_hparams()
+
+    def new_model():
+        # Create a model that tracks epochs and batches seen
+        model = LightningTestModel(hparams)
+        model.num_epochs_seen = 0
+        model.num_batches_seen = 0
+
+        def increment_epoch(self):
+            self.num_epochs_seen += 1
+
+        def increment_batch(self, _):
+            self.num_batches_seen += 1
+
+        # Bind the increment_epoch function on_epoch_end so that the
+        # model keeps track of the number of epochs it has seen.
+        model.on_epoch_end = types.MethodType(increment_epoch, model)
+        model.on_batch_start = types.MethodType(increment_batch, model)
+        return model
+
+    model = new_model()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.65,
+        val_percent_check=1,
+        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        logger=False,
+        default_save_path=tmpdir,
+        early_stop_callback=False,
+        val_check_interval=0.5,
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    trainer.fit(model)
+
+    training_batches = trainer.num_training_batches
+
+    assert model.num_epochs_seen == 2
+    assert model.num_batches_seen == training_batches * 2
+
+    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
+    checkpoints = [
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"),
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"),
+    ]
+
+    for check in checkpoints:
+        next_model = new_model()
+        state = torch.load(check)
+
+        # Resume training
+        trainer_options['max_epochs'] = 4
+        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
+        new_trainer.fit(next_model)
+        assert state['global_step'] + next_model.num_batches_seen == training_batches * 4
+
+
 def test_multiple_test_dataloader(tmpdir):
     """Verify multiple test_dataloader."""
     tutils.reset_seed()
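
For readers tracing the final assertion in the new test: with accumulate_grad_batches at its default of 1, global_step advances once per training batch, so a checkpoint written at the end of epoch N (0-indexed) stores global_step == training_batches * (N + 1), and the resumed run should contribute exactly the batches remaining up to max_epochs. A standalone sketch of that accounting, with made-up numbers:

# Pure-Python accounting behind the test's final assert, using assumed figures.
training_batches = 65          # e.g. train_percent_check=0.65 of a 100-batch dataset
epochs_before_resume = 2       # the epoch-1 checkpoint is written once epoch 1 (0-indexed) finishes
global_step_in_ckpt = training_batches * epochs_before_resume

max_epochs = 4
batches_seen_after_resume = training_batches * (max_epochs - epochs_before_resume)

# Matches: state['global_step'] + next_model.num_batches_seen == training_batches * 4
assert global_step_in_ckpt + batches_seen_after_resume == training_batches * max_epochs
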