diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 7bcf98b9b4..5edc5ad0c3 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -309,6 +309,7 @@ class Trainer(TrainerIO): if self.__is_function_implemented('on_batch_end'): self.model.on_batch_end() + # end epoch early if early_stop_epoch: break @@ -326,6 +327,7 @@ class Trainer(TrainerIO): if stop: return + def __run_tng_batch(self, data_batch): if data_batch is None: return 0