early epoch stopping

This commit is contained in:
William Falcon 2019-04-23 08:46:20 -04:00
parent ffd6dc678c
commit 95aee7ff96
1 changed files with 5 additions and 3 deletions

View File

@ -271,14 +271,13 @@ class Trainer(TrainerIO):
# RUN TRAIN STEP # RUN TRAIN STEP
# --------------- # ---------------
batch_result = self.__run_tng_batch(data_batch) batch_result = self.__run_tng_batch(data_batch)
if batch_result == -1: early_stop_epoch = batch_result == -1
break
# --------------- # ---------------
# RUN VAL STEP # RUN VAL STEP
# --------------- # ---------------
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0 is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
if self.fast_dev_run or is_val_check_batch: if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
self.__run_validation() self.__run_validation()
# when batch should be saved # when batch should be saved
@ -310,6 +309,9 @@ class Trainer(TrainerIO):
if self.__is_function_implemented('on_batch_end'): if self.__is_function_implemented('on_batch_end'):
self.model.on_batch_end() self.model.on_batch_end()
if early_stop_epoch:
break
# hook # hook
if self.__is_function_implemented('on_epoch_end'): if self.__is_function_implemented('on_epoch_end'):
self.model.on_epoch_end() self.model.on_epoch_end()