Fixes resuming checkpoints rerunning last epoch (#866)
* Properly restore current epoch and global step on resume
* Add test
* Move increment to saving rather than loading
* Fix other tests that refer to current epoch
* Formatting
* Add warning for mid-epoch resuming
* Formatting
* Fix warning check for accumulated batches
* Add variable to init
* Formatting
* Add check for 0 training steps
* Make check more readable
parent 2b5458e852
commit 6e7dc9c236
@@ -125,6 +125,8 @@ class TrainerIOMixin(ABC):
         self.early_stop_callback = None
         self.lr_schedulers = None
         self.optimizers = None
+        self.num_training_batches = None
+        self.accumulate_grad_batches = None

     def get_model(self):
         is_dp_module = isinstance(self.model, (LightningDistributedDataParallel,
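Why these two placeholders are declared here: the resume logic added further down in this diff reads both attributes, and treats an unset accumulate_grad_batches (None) as "no accumulation". A minimal sketch of that fallback, with made-up values rather than the real trainer wiring:

# Stand-in values; in the trainer these start as None (as declared above) and are
# filled in elsewhere before restore runs.
num_training_batches = 100        # an int once the train dataloader has been counted
accumulate_grad_batches = None    # left unset here to show the fallback

# Same fallback used by restore_training_state below: None means "step every batch".
n_accum = 1 if accumulate_grad_batches is None else accumulate_grad_batches
print(n_accum)  # -> 1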
@@ -305,10 +307,9 @@ class TrainerIOMixin(ABC):
         self.restore_training_state(checkpoint)

     def dump_checkpoint(self):

         checkpoint = {
-            'epoch': self.current_epoch,
-            'global_step': self.global_step
+            'epoch': self.current_epoch + 1,
+            'global_step': self.global_step + 1,
         }

         if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
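The core change of the PR is the hunk above: the +1 now happens when the checkpoint is written, not when it is loaded. A toy round trip under that scheme (plain functions with made-up numbers, not Lightning's actual dump_checkpoint/restore path):

# Simplified sketch: the saved epoch is the next epoch to run, so restoring it
# no longer repeats the epoch that had just finished (the bug in the title).
def dump(current_epoch, global_step):
    return {'epoch': current_epoch + 1, 'global_step': global_step + 1}

def restore(checkpoint):
    # mirrors restore_training_state: values are taken as-is, no increment on load
    return checkpoint['epoch'], checkpoint['global_step']

ckpt = dump(current_epoch=1, global_step=199)   # saved right after finishing epoch 1
epoch, step = restore(ckpt)
print(epoch, step)                              # -> 2 200: training resumes at epoch 2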
@@ -388,6 +389,17 @@ class TrainerIOMixin(ABC):
         self.global_step = checkpoint['global_step']
         self.current_epoch = checkpoint['epoch']

+        # Division deals with global step stepping once per accumulated batch
+        # Inequality deals with different global step for odd vs even num_training_batches
+        n_accum = 1 if self.accumulate_grad_batches is None else self.accumulate_grad_batches
+        expected_steps = self.num_training_batches / n_accum
+        if self.num_training_batches != 0 and self.global_step % expected_steps > 1:
+            warnings.warn(
+                "You're resuming from a checkpoint that ended mid-epoch. "
+                "This can cause unreliable results if further training is done, "
+                "consider using an end of epoch checkpoint. "
+            )
+
         # restore the optimizers
         optimizer_states = checkpoint['optimizer_states']
         for optimizer, opt_state in zip(self.optimizers, optimizer_states):
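A worked example of the new warning condition, with made-up numbers. Per the inline comments in the hunk, the division converts batches to optimizer steps under gradient accumulation, and the slack of one step (> 1 rather than != 0) absorbs the odd/even rounding of num_training_batches:

# With 100 training batches and accumulate_grad_batches=4, the optimizer steps
# 25 times per epoch, so global_step is a multiple of 25 at every epoch boundary.
num_training_batches = 100
accumulate_grad_batches = 4
expected_steps = num_training_batches / accumulate_grad_batches   # 25.0 optimizer steps per epoch

for global_step in (25, 50, 63):
    mid_epoch = num_training_batches != 0 and global_step % expected_steps > 1
    print(global_step, mid_epoch)   # 25 False, 50 False, 63 True (checkpoint ended mid-epoch)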
@@ -214,8 +214,8 @@ def test_dp_resume(tmpdir):
     trainer.is_slurm_managing_tasks = True
     result = trainer.fit(model)

-    # track epoch before saving
-    real_global_epoch = trainer.current_epoch
+    # track epoch before saving. Increment since we finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1

     # correct result and ok accuracy
     assert result == 1, 'amp + dp model failed to complete'
@@ -282,7 +282,8 @@ def test_cpu_restore_training(tmpdir):
     # fit model
     trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
-    real_global_epoch = trainer.current_epoch
+    # Increment since we've finished the current epoch, don't want to rerun
+    real_global_epoch = trainer.current_epoch + 1

     # traning complete
     assert result == 1, 'amp + ddp model failed to complete'
@@ -413,6 +413,74 @@ def test_multiple_val_dataloader(tmpdir):
         tutils.run_prediction(dataloader, trainer.model)


+def test_resume_from_checkpoint_epoch_restored(tmpdir):
+    """Verify resuming from checkpoint runs the right number of epochs"""
+    import types
+
+    tutils.reset_seed()
+
+    hparams = tutils.get_hparams()
+
+    def new_model():
+        # Create a model that tracks epochs and batches seen
+        model = LightningTestModel(hparams)
+        model.num_epochs_seen = 0
+        model.num_batches_seen = 0
+
+        def increment_epoch(self):
+            self.num_epochs_seen += 1
+
+        def increment_batch(self, _):
+            self.num_batches_seen += 1
+
+        # Bind the increment_epoch function on_epoch_end so that the
+        # model keeps track of the number of epochs it has seen.
+        model.on_epoch_end = types.MethodType(increment_epoch, model)
+        model.on_batch_start = types.MethodType(increment_batch, model)
+        return model
+
+    model = new_model()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_epochs=2,
+        train_percent_check=0.65,
+        val_percent_check=1,
+        checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1),
+        logger=False,
+        default_save_path=tmpdir,
+        early_stop_callback=False,
+        val_check_interval=0.5,
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    trainer.fit(model)
+
+    training_batches = trainer.num_training_batches
+
+    assert model.num_epochs_seen == 2
+    assert model.num_batches_seen == training_batches * 2
+
+    # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
+    checkpoints = [
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0_v0.ckpt"),
+        # os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt"),
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1_v0.ckpt"),
+    ]
+
+    for check in checkpoints:
+        next_model = new_model()
+        state = torch.load(check)
+
+        # Resume training
+        trainer_options['max_epochs'] = 4
+        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
+        new_trainer.fit(next_model)
+        assert state['global_step'] + next_model.num_batches_seen == training_batches * 4
+
+
 def test_multiple_test_dataloader(tmpdir):
     """Verify multiple test_dataloader."""
     tutils.reset_seed()
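A short usage sketch of the resume flag the new test exercises (the checkpoint path below is illustrative; any checkpoint written by ModelCheckpoint works):

from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=4,
    resume_from_checkpoint="checkpoints/_ckpt_epoch_1_v0.ckpt",  # illustrative path
)
# trainer.fit(model) then starts from the epoch stored in the checkpoint
# (current_epoch + 1 at save time) instead of repeating the finished epoch.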