Early stopping (#332)
* callbacks use all other keys in return dict
* remove os.exit from early stopping
parent 6e3e740a7f
commit dcaba55251
@@ -125,7 +125,8 @@ class EarlyStopping(Callback):
             print('Early stopping conditioned on metric `%s` '
                   'which is not available. Available metrics are: %s' %
                   (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
-            exit(-1)
+            stop_training = True
+            return stop_training
 
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
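With exit(-1) gone, a missing monitored metric now triggers a warning and a stop request instead of killing the process. A minimal sketch of how a caller sees that, assuming the Keras-style constructor arguments and the pytorch_lightning.callbacks import path of this era of the library (the metric values are made up):

# Illustrative values; the import path is an assumption for this version of the library.
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0, patience=3)

# 'val_loss' is missing from logs, so the callback warns and returns True
# (a request to stop training) instead of calling exit(-1).
should_stop = early_stop.on_epoch_end(epoch=0, logs={'loss': 0.52})
print(should_stop)  # True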
@@ -148,6 +148,7 @@ class Trainer(TrainerIO):
         self.avg_loss = 0
         self.batch_nb = 0
         self.tqdm_metrics = {}
+        self.callback_metrics = {}
         self.nb_val_batches = 0
         self.nb_training_batches = 0
         self.nb_test_batches = 0
@@ -1061,7 +1062,7 @@ class Trainer(TrainerIO):
             met_min_epochs = epoch_nb > self.min_nb_epochs
             if self.enable_early_stop and met_min_epochs:
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
-                                                                    logs=self.__training_tqdm_dict)
+                                                                    logs=self.callback_metrics)
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:
@@ -1237,8 +1238,9 @@ class Trainer(TrainerIO):
             output = self.model.training_step(*args)
 
         # format and reduce outputs accordingly
-        loss, progress_bar_metrics, log_metrics = self.__process_output(output, train=True)
-        return loss, progress_bar_metrics, log_metrics
+        output = self.__process_output(output, train=True)
+        loss, progress_bar_metrics, log_metrics, callback_metrics = output
+        return loss, progress_bar_metrics, log_metrics, callback_metrics
 
     def __process_output(self, output, train=False):
         """
@@ -1247,6 +1249,12 @@ class Trainer(TrainerIO):
         :param output:
         :return:
         """
+        # all keys not progress_bar or log are candidates for callbacks
+        callback_metrics = {}
+        for k, v in output.items():
+            if k not in ['progress_bar', 'log']:
+                callback_metrics[k] = v
+
         try:
             progress_output = output['progress_bar']
 
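The new block in __process_output treats every key other than 'progress_bar' and 'log' as a callback metric. A standalone sketch of that filtering rule, using a hypothetical helper name and an illustrative training_step-style return dict:

# split_callback_metrics is a hypothetical helper that mirrors the loop added above.
def split_callback_metrics(output):
    # every key except the reserved 'progress_bar' and 'log' is exposed to callbacks
    return {k: v for k, v in output.items() if k not in ['progress_bar', 'log']}

output = {
    'loss': 0.31,
    'val_acc': 0.87,
    'progress_bar': {'train_loss': 0.31},
    'log': {'train_loss': 0.31},
}
print(split_callback_metrics(output))  # {'loss': 0.31, 'val_acc': 0.87}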
@@ -1293,7 +1301,7 @@ class Trainer(TrainerIO):
         if self.use_dp or self.use_ddp2:
             loss = reduce_distributed_output(loss, self.num_gpus)
 
-        return loss, progress_bar_metrics, log_metrics
+        return loss, progress_bar_metrics, log_metrics, callback_metrics
 
     def __clip_gradients(self):
         if self.gradient_clip_val > 0:
@@ -1310,6 +1318,9 @@ class Trainer(TrainerIO):
         # track grad norms
         grad_norm_dic = {}
 
+        # track all metrics for callbacks
+        all_callback_metrics = []
+
         # track metrics to log
         all_log_metrics = []
 
@@ -1334,7 +1345,10 @@ class Trainer(TrainerIO):
             def optimizer_closure():
                 # forward pass
                 output = self.__training_forward(batch, batch_nb, opt_idx)
-                closure_loss, progress_bar_metrics, log_metrics = output
+                closure_loss, progress_bar_metrics, log_metrics, callback_metrics = output
+
+                # track metrics for callbacks
+                all_callback_metrics.append(callback_metrics)
 
                 # track progress bar metrics
                 self.__add_tqdm_metrics(progress_bar_metrics)
@@ -1404,6 +1418,10 @@ class Trainer(TrainerIO):
 
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
 
+        # track all metrics for callbacks
+        self.callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}
+
         return 0, grad_norm_dic, all_log_metrics
 
     def __run_evaluation(self, test=False):
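Since optimizer_closure can run once per optimizer, all_callback_metrics is a list of dicts; the comprehension above flattens it into a single self.callback_metrics dict, with later entries overwriting earlier ones on duplicate keys. A small illustration with made-up values:

# made-up per-optimizer metric dicts, e.g. from a multi-optimizer (GAN-style) step
all_callback_metrics = [{'loss': 0.42, 'val_acc': 0.80}, {'disc_loss': 0.91}]

callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}
print(callback_metrics)  # {'loss': 0.42, 'val_acc': 0.80, 'disc_loss': 0.91}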
|
@ -1444,14 +1462,17 @@ class Trainer(TrainerIO):
|
|||
dataloaders,
|
||||
max_batches,
|
||||
test)
|
||||
_, progress_bar_metrics, log_metrics = self.__process_output(eval_results)
|
||||
_, prog_bar_metrics, log_metrics, callback_metrics = self.__process_output(eval_results)
|
||||
|
||||
# add metrics to prog bar
|
||||
self.__add_tqdm_metrics(progress_bar_metrics)
|
||||
self.__add_tqdm_metrics(prog_bar_metrics)
|
||||
|
||||
# log metrics
|
||||
self.__log_metrics(log_metrics, {})
|
||||
|
||||
# track metrics for callbacks
|
||||
self.callback_metrics = callback_metrics
|
||||
|
||||
# hook
|
||||
model.on_post_performance_check()
|
||||
|
||||
|
@@ -1463,4 +1484,4 @@ class Trainer(TrainerIO):
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
             self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
-                                                  logs=self.__training_tqdm_dict)
+                                                  logs=self.callback_metrics)
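With both early stopping and checkpointing now fed self.callback_metrics as logs, the callbacks see exactly the extra keys returned from the model's step methods. A hedged sketch of the normal patience flow, assuming the EarlyStopping signature and import path of this version; the per-epoch dicts stand in for what self.callback_metrics would contain:

from pytorch_lightning.callbacks import EarlyStopping  # assumed import path for this version

early_stop = EarlyStopping(monitor='val_loss', min_delta=0.0, patience=1, mode='min')

# dicts standing in for self.callback_metrics at the end of each epoch
history = [{'val_loss': 0.90}, {'val_loss': 0.70}, {'val_loss': 0.80}, {'val_loss': 0.85}]

for epoch, metrics in enumerate(history):
    should_stop = early_stop.on_epoch_end(epoch=epoch, logs=metrics)
    if should_stop:
        print('stopping at epoch %d' % epoch)  # epoch 2: no improvement within patience
        break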