From e89da15f189cdb8e82159a9167c67098b256b89d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 21 Apr 2019 13:38:50 -0400 Subject: [PATCH] if return -1 from a hook that loop stopps --- pytorch_lightning/models/trainer.py | 8 +++----- pytorch_lightning/root_module/hooks.py | 2 -- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index c1ba181f67..87767d7d59 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -266,10 +266,6 @@ class Trainer(TrainerIO): if met_batch_limit: break - # give model a chance to end epoch early - if self.model.should_stop_epoch(data_batch): - break - # --------------- # RUN TRAIN STEP # --------------- @@ -331,7 +327,9 @@ class Trainer(TrainerIO): # hook if self.__is_function_implemented('on_batch_start'): - self.model.on_batch_start() + response = self.model.on_batch_start() + if response == -1: + return if self.enable_tqdm: self.prog_bar.update(1) diff --git a/pytorch_lightning/root_module/hooks.py b/pytorch_lightning/root_module/hooks.py index 02471b2952..0d903f6627 100644 --- a/pytorch_lightning/root_module/hooks.py +++ b/pytorch_lightning/root_module/hooks.py @@ -19,5 +19,3 @@ class ModelHooks(torch.nn.Module): def on_post_performance_check(self): pass - def should_stop_epoch(self, data_batch): - return False