From 0c2e315950409e84358be50a443c05165b68bcd2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 9 Sep 2020 01:05:50 -0400 Subject: [PATCH] ref: moved eval loop logging to loggers 1/n (#3408) --- pytorch_lightning/trainer/evaluate_loop.py | 9 ++- pytorch_lightning/trainer/evaluation_loop.py | 74 +------------------ pytorch_lightning/trainer/logger_connector.py | 58 ++++++++++++++- 3 files changed, 66 insertions(+), 75 deletions(-) diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py index ddeed31dec..824cd36512 100644 --- a/pytorch_lightning/trainer/evaluate_loop.py +++ b/pytorch_lightning/trainer/evaluate_loop.py @@ -141,9 +141,14 @@ class EvaluationLoop(object): eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) return eval_results - def log_epoch_metrics(self, eval_results): + def log_epoch_metrics(self, eval_results, test_mode): using_eval_result = self.is_using_eval_results() - self.trainer.logger_connector.on_evaluation_epoch_end(eval_results, using_eval_result) + eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( + eval_results, + using_eval_result, + test_mode + ) + return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): model = self.trainer.get_model() diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index caba7defe6..c8b5ea3312 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -124,33 +124,16 @@ In this second case, the options you pass to trainer will be used when running """ from abc import ABC, abstractmethod -from pprint import pprint -from typing import Callable, List, Union +from typing import Callable, List import torch from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType -from pytorch_lightning.core.step_result import EvalResult, Result +from pytorch_lightning.utilities import AMPType from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop from pytorch_lightning.trainer.logger_connector import LoggerConnector -try: - import torch_xla.distributed.parallel_loader as xla_pl - import torch_xla.core.xla_model as xm -except ImportError: - XLA_AVAILABLE = False -else: - XLA_AVAILABLE = True - -try: - import horovod.torch as hvd -except (ModuleNotFoundError, ImportError): - HOROVOD_AVAILABLE = False -else: - HOROVOD_AVAILABLE = True - class TrainerEvaluationLoopMixin(ABC): @@ -265,15 +248,12 @@ class TrainerEvaluationLoopMixin(ABC): eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) # bookkeeping - self.evaluation_loop.log_epoch_metrics(eval_results) + eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode) self.evaluation_loop.predictions.to_disk() # hook self.evaluation_loop.on_evaluation_epoch_end() - # log the final eval loop metrics - eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode) - # enable train mode again model.train() torch.set_grad_enabled(True) @@ -282,51 +262,3 @@ class TrainerEvaluationLoopMixin(ABC): self.evaluation_loop.on_evaluation_end() return eval_loop_results, eval_results - - def __log_evaluation_epoch_metrics(self, eval_results, test_mode): - if self.running_sanity_check: - return - - eval_loop_results = [] - if eval_results is not None and len(eval_results) > 0: - - # in eval, the user may return something at every validation step without final reduction - if not isinstance(eval_results, list): - eval_results = [eval_results] - - for result_idx, result in enumerate(eval_results): - if isinstance(result, EvalResult): - prog_bar_metrics = result.epoch_pbar_metrics - log_metrics = result.epoch_log_metrics - callback_metrics = result.callback_metrics - - # in testing we don't need the callback metrics - if test_mode: - callback_metrics = {} - else: - _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result) - - # eval loop returns all metrics - dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics} - - # add metrics to prog bar - self.logger_connector.add_progress_bar_metrics(prog_bar_metrics) - - # log metrics - self.logger_connector.log_metrics(log_metrics, {}) - - # track metrics for callbacks - self.logger_connector.callback_metrics.update(callback_metrics) - - if len(dataloader_result_metrics) > 0: - eval_loop_results.append(dataloader_result_metrics) - - # log results of test - if test_mode and self.is_global_zero and self.verbose_test: - print('-' * 80) - for result_idx, results in enumerate(eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') - pprint(results) - print('-' * 80) - - return eval_loop_results diff --git a/pytorch_lightning/trainer/logger_connector.py b/pytorch_lightning/trainer/logger_connector.py index 3148f4ac59..883f344869 100644 --- a/pytorch_lightning/trainer/logger_connector.py +++ b/pytorch_lightning/trainer/logger_connector.py @@ -15,7 +15,8 @@ import torch from pytorch_lightning.core import memory from pytorch_lightning.utilities import flatten_dict from pytorch_lightning.utilities.model_utils import is_overridden -from pytorch_lightning.core.step_result import Result +from pytorch_lightning.core.step_result import EvalResult, Result +from pprint import pprint class LoggerConnector: @@ -73,7 +74,12 @@ class LoggerConnector: self.trainer.dev_debugger.track_pbar_metrics_history(metrics) - def on_evaluation_epoch_end(self, eval_results, using_eval_result): + def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode): + # TODO: merge both functions? + self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result) + return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode) + + def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result): if using_eval_result: if isinstance(eval_results, list): for eval_result in eval_results: @@ -97,6 +103,54 @@ class LoggerConnector: flat = flatten_dict(eval_results) self.trainer.logger_connector.callback_metrics.update(flat) + def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode): + if self.trainer.running_sanity_check: + return + + eval_loop_results = [] + if eval_results is not None and len(eval_results) > 0: + + # in eval, the user may return something at every validation step without final reduction + if not isinstance(eval_results, list): + eval_results = [eval_results] + + for result_idx, result in enumerate(eval_results): + if isinstance(result, EvalResult): + prog_bar_metrics = result.epoch_pbar_metrics + log_metrics = result.epoch_log_metrics + callback_metrics = result.callback_metrics + + # in testing we don't need the callback metrics + if test_mode: + callback_metrics = {} + else: + _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_output(result) + + # eval loop returns all metrics + dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics} + + # add metrics to prog bar + self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics) + + # log metrics + self.trainer.logger_connector.log_metrics(log_metrics, {}) + + # track metrics for callbacks + self.trainer.logger_connector.callback_metrics.update(callback_metrics) + + if len(dataloader_result_metrics) > 0: + eval_loop_results.append(dataloader_result_metrics) + + # log results of test + if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: + print('-' * 80) + for result_idx, results in enumerate(eval_loop_results): + print(f'DATALOADER:{result_idx} TEST RESULTS') + pprint(results) + print('-' * 80) + + return eval_loop_results + def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers): self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers)