force crash when max_epochs < epochs in a checkpoint (#3580)

* force crash when max_epochs < epochs in a checkpoint

William Falcon, 2020-09-20 22:04:22 -04:00, committed by GitHub
parent a71d62d840
commit 277538970d
2 changed files with 13 additions and 1 deletion
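
In words: before this change, restoring a checkpoint whose recorded epoch already exceeded the new Trainer's max_epochs silently restored state and then ran no training. Below is a minimal sketch of the now-fatal misconfiguration; the module, paths, and hyperparameters are hypothetical, and only Trainer(max_epochs=...) and resume_from_checkpoint are taken from the diff itself:

# Minimal sketch of the misconfiguration this commit turns into a hard
# failure. TinyModule and the checkpoint path are illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class TinyModule(LightningModule):
    # Smallest possible module: "fit" a linear layer to random data.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return {'loss': self.layer(x).pow(2).mean()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)

# Run 1: train to max_epochs=10; Lightning writes a checkpoint recording
# the epoch it stopped at.
Trainer(max_epochs=10, default_root_dir='ckpts').fit(TinyModule())

# Run 2: the epoch restored from the checkpoint exceeds max_epochs=5, so
# the checkpoint connector now raises MisconfigurationException instead of
# restoring state and then training nothing.
Trainer(max_epochs=5, resume_from_checkpoint='ckpts/last.ckpt').fit(TinyModule())

Note that the new check uses a strict > comparison, so resuming with current_epoch equal to max_epochs is still allowed; such a run restores state and exits immediately rather than erroring.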

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 try:
     from apex import amp
@@ -145,6 +146,14 @@ class CheckpointConnector:
         self.trainer.global_step = checkpoint['global_step']
         self.trainer.current_epoch = checkpoint['epoch']
 
+        # crash if max_epochs is lower than the current epoch from the checkpoint
+        if self.trainer.current_epoch > self.trainer.max_epochs:
+            m = f"""
+            you restored a checkpoint with current_epoch={self.trainer.current_epoch}
+            but the Trainer(max_epochs={self.trainer.max_epochs})
+            """
+            raise MisconfigurationException(m)
+
         # Division deals with global step stepping once per accumulated batch
         # Inequality deals with different global step for odd vs even num_training_batches
         n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
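
The two pre-existing comments at the tail of this hunk refer to a separate sanity check on the restored global_step. A standalone sketch of the arithmetic they describe (illustrative only, with hypothetical names; not the connector's actual code):

# Illustration: why a restored global_step is checked with a division and
# an inequality rather than strict equality. global_step counts optimizer
# steps, and one optimizer step happens per n_accum batches.
from typing import Optional

def expected_global_step(epoch: int, num_training_batches: int,
                         accumulate_grad_batches: Optional[int]) -> int:
    n_accum = 1 if accumulate_grad_batches is None else accumulate_grad_batches
    # Integer division: a trailing partial accumulation window may or may
    # not produce a final optimizer step, which is why odd vs. even
    # num_training_batches can differ by one step per epoch.
    return epoch * (num_training_batches // n_accum)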

tests/callbacks/test_early_stopping.py

@@ -9,6 +9,7 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from tests.base import EvalModelTemplate
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class EarlyStoppingTestRestore(EarlyStopping):
@@ -63,6 +64,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
         resume_from_checkpoint=checkpoint_filepath,
         early_stop_callback=early_stop_callback,
     )
-    new_trainer.fit(model)
+
+    with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
+        new_trainer.fit(model)
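
A note on the assertion style: pytest.raises(..., match=...) applies re.search to the string form of the raised exception, so the pattern only has to match somewhere inside the multi-line f-string message; the stray trailing * in the pattern (which merely makes the final 'h' optional) is harmless for the same reason.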