force crash when max_epochs < epochs in a checkpoint (#3580)
* force crash when max_epochs < epochs in a checkpoint
This commit is contained in:
parent
a71d62d840
commit
277538970d
|
@ -33,6 +33,7 @@ from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
|
|||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
|
||||
from pytorch_lightning.accelerators.base_backend import Accelerator
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -145,6 +146,14 @@ class CheckpointConnector:
|
|||
self.trainer.global_step = checkpoint['global_step']
|
||||
self.trainer.current_epoch = checkpoint['epoch']
|
||||
|
||||
# crash if max_epochs is lower than the current epoch from the checkpoint
|
||||
if self.trainer.current_epoch > self.trainer.max_epochs:
|
||||
m = f"""
|
||||
you restored a checkpoint with current_epoch={self.trainer.current_epoch}
|
||||
but the Trainer(max_epochs={self.trainer.max_epochs})
|
||||
"""
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
# Division deals with global step stepping once per accumulated batch
|
||||
# Inequality deals with different global step for odd vs even num_training_batches
|
||||
n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
|
||||
|
|
|
@ -9,6 +9,7 @@ import torch
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from tests.base import EvalModelTemplate
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class EarlyStoppingTestRestore(EarlyStopping):
|
||||
|
@ -63,6 +64,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
|
|||
resume_from_checkpoint=checkpoint_filepath,
|
||||
early_stop_callback=early_stop_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match=r'.*you restored a checkpoint with current_epoch*'):
|
||||
new_trainer.fit(model)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue