if return -1 from a hook that loop stopps
This commit is contained in:
parent
004f015ee0
commit
e89da15f18
|
@ -266,10 +266,6 @@ class Trainer(TrainerIO):
|
|||
if met_batch_limit:
|
||||
break
|
||||
|
||||
# give model a chance to end epoch early
|
||||
if self.model.should_stop_epoch(data_batch):
|
||||
break
|
||||
|
||||
# ---------------
|
||||
# RUN TRAIN STEP
|
||||
# ---------------
|
||||
|
@ -331,7 +327,9 @@ class Trainer(TrainerIO):
|
|||
|
||||
# hook
|
||||
if self.__is_function_implemented('on_batch_start'):
|
||||
self.model.on_batch_start()
|
||||
response = self.model.on_batch_start()
|
||||
if response == -1:
|
||||
return
|
||||
|
||||
if self.enable_tqdm:
|
||||
self.prog_bar.update(1)
|
||||
|
|
|
@ -19,5 +19,3 @@ class ModelHooks(torch.nn.Module):
|
|||
def on_post_performance_check(self):
|
||||
pass
|
||||
|
||||
def should_stop_epoch(self, data_batch):
|
||||
return False
|
||||
|
|
Loading…
Reference in New Issue