diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 8039d03870..7bcf98b9b4 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -271,14 +271,13 @@ class Trainer(TrainerIO): # RUN TRAIN STEP # --------------- batch_result = self.__run_tng_batch(data_batch) - if batch_result == -1: - break + early_stop_epoch = batch_result == -1 # --------------- # RUN VAL STEP # --------------- is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0 - if self.fast_dev_run or is_val_check_batch: + if self.fast_dev_run or is_val_check_batch or early_stop_epoch: self.__run_validation() # when batch should be saved @@ -310,6 +309,9 @@ class Trainer(TrainerIO): if self.__is_function_implemented('on_batch_end'): self.model.on_batch_end() + if early_stop_epoch: + break + # hook if self.__is_function_implemented('on_epoch_end'): self.model.on_epoch_end()