Early stopping (#332)

* callbacks use all other keys in return dict

* remove os.exit from early stopping
William Falcon 2019-10-08 16:21:00 -04:00 committed by GitHub
parent 6e3e740a7f
commit dcaba55251
2 changed files with 31 additions and 9 deletions


@@ -125,7 +125,8 @@ class EarlyStopping(Callback):
print('Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
- exit(-1)
+ stop_training = True
+ return stop_training

if self.monitor_op(current - self.min_delta, self.best):
self.best = current
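
The behavioural change in this hunk: when the monitored key is missing, the callback no longer kills the interpreter with `exit(-1)`; it warns and returns a stop flag that the caller can act on. A self-contained toy version of that flow follows; the class name, defaults, and the `wait`/`patience` counters are illustrative, not the library's exact implementation.

```python
import numpy as np

class ToyEarlyStopping:
    """Toy illustration of the return-a-flag flow; not the library class."""

    def __init__(self, monitor='val_loss', min_delta=0.0, patience=3):
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0
        self.best = np.inf
        self.monitor_op = np.less  # 'min' mode: lower is better

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            # metric missing: warn and ask the caller to stop, never exit(-1)
            print('Early stopping conditioned on metric `%s` which is not '
                  'available. Available metrics are: %s'
                  % (self.monitor, ','.join(logs.keys())))
            return True

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current   # improvement: reset patience
            self.wait = 0
            return False

        self.wait += 1            # no improvement this epoch
        return self.wait >= self.patience
```

Either way the decision is handed back to the Trainer, which can then finish checkpointing and cleanup instead of dying mid-epoch.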


@@ -148,6 +148,7 @@ class Trainer(TrainerIO):
self.avg_loss = 0
self.batch_nb = 0
self.tqdm_metrics = {}
+ self.callback_metrics = {}
self.nb_val_batches = 0
self.nb_training_batches = 0
self.nb_test_batches = 0
@@ -1061,7 +1062,7 @@ class Trainer(TrainerIO):
met_min_epochs = epoch_nb > self.min_nb_epochs
if self.enable_early_stop and met_min_epochs:
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
- logs=self.__training_tqdm_dict)
+ logs=self.callback_metrics)

# stop training
stop = should_stop and met_min_epochs
if stop:
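
For reference, a standalone sketch of the stop decision this hunk modifies, using the toy callback above. The helper name and argument list are hypothetical; the logic mirrors the diff, where callbacks are now fed `callback_metrics` rather than the tqdm display dict.

```python
def should_stop_training(early_stop_callback, callback_metrics,
                         epoch_nb, min_nb_epochs, enable_early_stop=True):
    """Hypothetical helper mirroring the epoch-end check in the diff."""
    met_min_epochs = epoch_nb > min_nb_epochs
    if enable_early_stop and met_min_epochs:
        # callbacks now receive every user-returned key that is not
        # 'progress_bar' or 'log', instead of the progress-bar dict
        return early_stop_callback.on_epoch_end(epoch=epoch_nb,
                                                logs=callback_metrics)
    return False

# e.g. with the toy callback above:
# should_stop_training(ToyEarlyStopping('val_loss'), {'val_loss': 0.31},
#                      epoch_nb=5, min_nb_epochs=1)
```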
@@ -1237,8 +1238,9 @@
output = self.model.training_step(*args)

# format and reduce outputs accordingly
- loss, progress_bar_metrics, log_metrics = self.__process_output(output, train=True)
- return loss, progress_bar_metrics, log_metrics
+ output = self.__process_output(output, train=True)
+ loss, progress_bar_metrics, log_metrics, callback_metrics = output
+ return loss, progress_bar_metrics, log_metrics, callback_metrics

def __process_output(self, output, train=False):
"""
@@ -1247,6 +1249,12 @@
:param output:
:return:
"""
+ # all keys not progress_bar or log are candidates for callbacks
+ callback_metrics = {}
+ for k, v in output.items():
+     if k not in ['progress_bar', 'log']:
+         callback_metrics[k] = v

try:
progress_output = output['progress_bar']
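
The split this hunk adds to `__process_output` can be shown standalone: everything the user returns outside `'progress_bar'` and `'log'` becomes a callback metric. The helper name below is illustrative, not the Trainer's private method.

```python
def split_output(output):
    """Illustrative version of the key split; not the Trainer's private method."""
    progress_bar_metrics = output.get('progress_bar', {})
    log_metrics = output.get('log', {})
    # all keys not progress_bar or log are candidates for callbacks
    callback_metrics = {k: v for k, v in output.items()
                        if k not in ['progress_bar', 'log']}
    return progress_bar_metrics, log_metrics, callback_metrics

# A training_step/validation_end return value such as
#   {'loss': 0.42, 'val_acc': 0.91, 'log': {'loss': 0.42}}
# yields callback_metrics == {'loss': 0.42, 'val_acc': 0.91}.
```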
@@ -1293,7 +1301,7 @@
if self.use_dp or self.use_ddp2:
loss = reduce_distributed_output(loss, self.num_gpus)

- return loss, progress_bar_metrics, log_metrics
+ return loss, progress_bar_metrics, log_metrics, callback_metrics

def __clip_gradients(self):
if self.gradient_clip_val > 0:
@@ -1310,6 +1318,9 @@
# track grad norms
grad_norm_dic = {}

+ # track all metrics for callbacks
+ all_callback_metrics = []

# track metrics to log
all_log_metrics = []
@@ -1334,7 +1345,10 @@
def optimizer_closure():
# forward pass
output = self.__training_forward(batch, batch_nb, opt_idx)
- closure_loss, progress_bar_metrics, log_metrics = output
+ closure_loss, progress_bar_metrics, log_metrics, callback_metrics = output
+
+ # track metrics for callbacks
+ all_callback_metrics.append(callback_metrics)

# track progress bar metrics
self.__add_tqdm_metrics(progress_bar_metrics)
@@ -1404,6 +1418,10 @@
# collapse all metrics into one dict
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}

+ # track all metrics for callbacks
+ self.callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}

return 0, grad_norm_dic, all_log_metrics

def __run_evaluation(self, test=False):
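
The per-batch dicts accumulated in `all_callback_metrics` are collapsed with the same dict comprehension already used for `all_log_metrics`. A tiny standalone example (values made up) of what that merge does, including the fact that later batches overwrite earlier ones on duplicate keys:

```python
# Made-up per-batch callback metrics from one epoch
all_callback_metrics = [{'loss': 0.93}, {'loss': 0.71, 'n_correct': 12}]

callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}
print(callback_metrics)  # {'loss': 0.71, 'n_correct': 12}  (last batch wins)
```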
@@ -1444,14 +1462,17 @@
dataloaders,
max_batches,
test)
- _, progress_bar_metrics, log_metrics = self.__process_output(eval_results)
+ _, prog_bar_metrics, log_metrics, callback_metrics = self.__process_output(eval_results)

# add metrics to prog bar
- self.__add_tqdm_metrics(progress_bar_metrics)
+ self.__add_tqdm_metrics(prog_bar_metrics)

# log metrics
self.__log_metrics(log_metrics, {})

+ # track metrics for callbacks
+ self.callback_metrics = callback_metrics

# hook
model.on_post_performance_check()
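
On the evaluation path the same fourth tuple element is stored on the Trainer, so the metrics produced by `validation_end` replace the training-step ones right before the early-stopping and checkpoint hooks run. A rough, self-contained sketch of that hand-off under the naming in this diff (values and dict names made up):

```python
# Stand-alone sketch of the evaluation hand-off: the metrics returned by
# validation_end become the Trainer's callback_metrics.
eval_results = {'val_loss': 0.28, 'progress_bar': {'val_loss': 0.28}}

callback_metrics = {k: v for k, v in eval_results.items()
                    if k not in ['progress_bar', 'log']}

trainer_state = {'callback_metrics': callback_metrics}  # stand-in for self.*
print(trainer_state['callback_metrics'])                # {'val_loss': 0.28}
```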
@@ -1463,4 +1484,4 @@
# model checkpointing
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
- logs=self.__training_tqdm_dict)
+ logs=self.callback_metrics)
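
Net effect for users: any extra top-level key returned from the step/end hooks is now visible to the early-stopping and checkpoint callbacks via `callback_metrics`, whether or not it appears in the progress bar. A hedged usage sketch; the hook and key names follow the era of this diff and are illustrative.

```python
import torch

def validation_end(outputs):
    """Illustrative validation_end body for a LightningModule of this era."""
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    return {
        'val_loss': avg_loss,                    # -> trainer.callback_metrics
        'progress_bar': {'val_loss': avg_loss},  # -> tqdm bar
        'log': {'val_loss': avg_loss},           # -> logger
    }

# EarlyStopping(monitor='val_loss') and the checkpoint callback now receive
# {'val_loss': ...} as their `logs`, so monitoring no longer depends on the
# progress-bar dict.
```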