early epoch stopping

This commit is contained in:
William Falcon 2019-04-23 08:26:48 -04:00
parent 676d76d839
commit 1961a6abb2
1 changed files with 6 additions and 2 deletions

View File

@ -270,7 +270,9 @@ class Trainer(TrainerIO):
# ---------------
# RUN TRAIN STEP
# ---------------
self.__run_tng_batch(data_batch)
batch_result = self.__run_tng_batch(data_batch)
if batch_result == -1:
break
# ---------------
# RUN VAL STEP
@ -330,7 +332,7 @@ class Trainer(TrainerIO):
if self.__is_function_implemented('on_batch_start'):
response = self.model.on_batch_start(data_batch)
if response == -1:
return
return -1
if self.enable_tqdm:
self.prog_bar.update(1)
@ -372,6 +374,8 @@ class Trainer(TrainerIO):
if self.__is_function_implemented('on_batch_end'):
self.model.on_batch_end()
return 0
def __run_validation(self):
# decide if can check epochs
can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0