diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index c1ba181f67..87767d7d59 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -266,10 +266,6 @@ class Trainer(TrainerIO): if met_batch_limit: break - # give model a chance to end epoch early - if self.model.should_stop_epoch(data_batch): - break - # --------------- # RUN TRAIN STEP # --------------- @@ -331,7 +327,9 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_batch_start'): - self.model.on_batch_start() + response = self.model.on_batch_start() + if response == -1: + return if self.enable_tqdm: self.prog_bar.update(1) diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 02471b2952..0d903f6627 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -19,5 +19,3 @@ class ModelHooks(torch.nn.Module): def on_post_performance_check(self): pass - def should_stop_epoch(self, data_batch): - return False