early epoch stopping
This commit is contained in:
parent
ffd6dc678c
commit
95aee7ff96
|
@ -271,14 +271,13 @@ class Trainer(TrainerIO):
|
||||||
# RUN TRAIN STEP
|
# RUN TRAIN STEP
|
||||||
# ---------------
|
# ---------------
|
||||||
batch_result = self.__run_tng_batch(data_batch)
|
batch_result = self.__run_tng_batch(data_batch)
|
||||||
if batch_result == -1:
|
early_stop_epoch = batch_result == -1
|
||||||
break
|
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# RUN VAL STEP
|
# RUN VAL STEP
|
||||||
# ---------------
|
# ---------------
|
||||||
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
|
is_val_check_batch = (batch_nb + 1) % self.val_check_batch == 0
|
||||||
if self.fast_dev_run or is_val_check_batch:
|
if self.fast_dev_run or is_val_check_batch or early_stop_epoch:
|
||||||
self.__run_validation()
|
self.__run_validation()
|
||||||
|
|
||||||
# when batch should be saved
|
# when batch should be saved
|
||||||
|
@ -310,6 +309,9 @@ class Trainer(TrainerIO):
|
||||||
if self.__is_function_implemented('on_batch_end'):
|
if self.__is_function_implemented('on_batch_end'):
|
||||||
self.model.on_batch_end()
|
self.model.on_batch_end()
|
||||||
|
|
||||||
|
if early_stop_epoch:
|
||||||
|
break
|
||||||
|
|
||||||
# hook
|
# hook
|
||||||
if self.__is_function_implemented('on_epoch_end'):
|
if self.__is_function_implemented('on_epoch_end'):
|
||||||
self.model.on_epoch_end()
|
self.model.on_epoch_end()
|
||||||
|
|
Loading…
Reference in New Issue