diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 877f74948a..e0aa2e6d34 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -125,7 +125,8 @@ class EarlyStopping(Callback):
             print('Early stopping conditioned on metric `%s` '
                   'which is not available. Available metrics are: %s' %
                   (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
-            exit(-1)
+            stop_training = True
+            return stop_training
 
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 7a7c8eed7c..46de7a008b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -148,6 +148,7 @@ class Trainer(TrainerIO):
         self.avg_loss = 0
         self.batch_nb = 0
         self.tqdm_metrics = {}
+        self.callback_metrics = {}
         self.nb_val_batches = 0
         self.nb_training_batches = 0
         self.nb_test_batches = 0
@@ -1061,7 +1062,7 @@ class Trainer(TrainerIO):
             met_min_epochs = epoch_nb > self.min_nb_epochs
             if self.enable_early_stop and met_min_epochs:
                 should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_nb,
-                                                                    logs=self.__training_tqdm_dict)
+                                                                    logs=self.callback_metrics)
                 # stop training
                 stop = should_stop and met_min_epochs
                 if stop:
@@ -1237,8 +1238,9 @@ class Trainer(TrainerIO):
             output = self.model.training_step(*args)
 
         # format and reduce outputs accordingly
-        loss, progress_bar_metrics, log_metrics = self.__process_output(output, train=True)
-        return loss, progress_bar_metrics, log_metrics
+        output = self.__process_output(output, train=True)
+        loss, progress_bar_metrics, log_metrics, callback_metrics = output
+        return loss, progress_bar_metrics, log_metrics, callback_metrics
 
     def __process_output(self, output, train=False):
         """
@@ -1247,6 +1249,12 @@ class Trainer(TrainerIO):
         :param output:
         :return:
         """
+        # all keys not progress_bar or log are candidates for callbacks
+        callback_metrics = {}
+        for k, v in output.items():
+            if k not in ['progress_bar', 'log']:
+                callback_metrics[k] = v
+
         try:
             progress_output = output['progress_bar']
 
@@ -1293,7 +1301,7 @@ class Trainer(TrainerIO):
         if self.use_dp or self.use_ddp2:
             loss = reduce_distributed_output(loss, self.num_gpus)
 
-        return loss, progress_bar_metrics, log_metrics
+        return loss, progress_bar_metrics, log_metrics, callback_metrics
 
     def __clip_gradients(self):
         if self.gradient_clip_val > 0:
@@ -1310,6 +1318,9 @@ class Trainer(TrainerIO):
         # track grad norms
         grad_norm_dic = {}
 
+        # track all metrics for callbacks
+        all_callback_metrics = []
+
         # track metrics to log
         all_log_metrics = []
 
@@ -1334,7 +1345,10 @@ class Trainer(TrainerIO):
         def optimizer_closure():
             # forward pass
             output = self.__training_forward(batch, batch_nb, opt_idx)
-            closure_loss, progress_bar_metrics, log_metrics = output
+            closure_loss, progress_bar_metrics, log_metrics, callback_metrics = output
+
+            # track metrics for callbacks
+            all_callback_metrics.append(callback_metrics)
 
             # track progress bar metrics
             self.__add_tqdm_metrics(progress_bar_metrics)
@@ -1404,6 +1418,10 @@ class Trainer(TrainerIO):
 
         # collapse all metrics into one dict
         all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+
+        # track all metrics for callbacks
+        self.callback_metrics = {k: v for d in all_callback_metrics for k, v in d.items()}
+
         return 0, grad_norm_dic, all_log_metrics
 
     def __run_evaluation(self, test=False):
@@ -1444,14 +1462,17 @@ class Trainer(TrainerIO):
                                               dataloaders,
                                               max_batches,
                                               test)
-        _, progress_bar_metrics, log_metrics = self.__process_output(eval_results)
+        _, prog_bar_metrics, log_metrics, callback_metrics = self.__process_output(eval_results)
 
         # add metrics to prog bar
-        self.__add_tqdm_metrics(progress_bar_metrics)
+        self.__add_tqdm_metrics(prog_bar_metrics)
 
         # log metrics
         self.__log_metrics(log_metrics, {})
 
+        # track metrics for callbacks
+        self.callback_metrics = callback_metrics
+
         # hook
         model.on_post_performance_check()
 
@@ -1463,4 +1484,4 @@ class Trainer(TrainerIO):
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
             self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
-                                                  logs=self.__training_tqdm_dict)
+                                                  logs=self.callback_metrics)
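For reference, a minimal sketch (not part of the diff) of the metric routing this change introduces: any key returned from the model's step/end methods that is not nested under 'progress_bar' or 'log' is collected into trainer.callback_metrics, and the early-stopping and checkpoint callbacks now receive that dict via logs= instead of the tqdm dict. The dict literal and key names below are illustrative assumptions, not values defined by this PR:

# Self-contained sketch (not repo code) mirroring the new filtering in
# Trainer.__process_output: everything outside 'progress_bar' and 'log'
# becomes a callback metric.
output = {
    'loss': 0.42,                           # assumed example values
    'val_acc': 0.91,
    'progress_bar': {'train_loss': 0.42},   # tqdm display only
    'log': {'train_loss': 0.42},            # logger only
}

callback_metrics = {k: v for k, v in output.items()
                    if k not in ['progress_bar', 'log']}
print(callback_metrics)  # {'loss': 0.42, 'val_acc': 0.91}

# With this diff, EarlyStopping.on_epoch_end(epoch=..., logs=callback_metrics)
# monitors these keys, and a missing monitored key now returns a stop flag
# rather than calling exit(-1).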