diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 897bef1d11..1bac00a63b 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -270,7 +270,9 @@ class Trainer(TrainerIO): # --------------- # RUN TRAIN STEP # --------------- - self.__run_tng_batch(data_batch) + batch_result = self.__run_tng_batch(data_batch) + if batch_result == -1: + break # --------------- # RUN VAL STEP @@ -330,7 +332,7 @@ class Trainer(TrainerIO): if self.__is_function_implemented('on_batch_start'): response = self.model.on_batch_start(data_batch) if response == -1: - return + return -1 if self.enable_tqdm: self.prog_bar.update(1) @@ -372,6 +374,8 @@ class Trainer(TrainerIO): if self.__is_function_implemented('on_batch_end'): self.model.on_batch_end() + return 0 + def __run_validation(self): # decide if can check epochs can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0