diff --git a/src/pytorch-lightning/models/trainer.py b/src/pytorch-lightning/models/trainer.py index 34c9e27b18..1088225f46 100644 --- a/src/pytorch-lightning/models/trainer.py +++ b/src/pytorch-lightning/models/trainer.py @@ -266,6 +266,10 @@ class Trainer(TrainerIO): if met_batch_limit: break + # give model a chance to end epoch early + if self.model.should_stop_epoch(): + break + # --------------- # RUN TRAIN STEP # --------------- diff --git a/src/pytorch-lightning/root_module/hooks.py b/src/pytorch-lightning/root_module/hooks.py index 3517f571db..e5a62f0115 100644 --- a/src/pytorch-lightning/root_module/hooks.py +++ b/src/pytorch-lightning/root_module/hooks.py @@ -18,3 +18,6 @@ class ModelHooks(torch.nn.Module): def on_post_performance_check(self): pass + + def should_stop_epoch(self): + return False