ref: moved eval loop logging to loggers 1/n (#3408)

William Falcon 2020-09-09 01:05:50 -04:00 committed by GitHub
parent 8f6b115511
commit 0c2e315950
3 changed files with 66 additions and 75 deletions
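
In short: EvaluationLoop.log_epoch_metrics now hands epoch-end logging to the LoggerConnector and returns whatever the connector returns, instead of the trainer mixin doing the logging itself. A minimal, runnable sketch of that delegation pattern, using stub classes in place of the real Trainer internals (names mirror the diff; the bodies are illustrative only, not library code):

from pprint import pprint


class StubLoggerConnector:
    """Stand-in for pytorch_lightning's LoggerConnector (illustrative only)."""

    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
        # the connector now owns the per-dataloader bookkeeping and
        # returns the flattened result dicts to its caller
        eval_loop_results = [dict(result) for result in eval_results]
        if test_mode:
            for idx, result in enumerate(eval_loop_results):
                print(f'DATALOADER:{idx} TEST RESULTS')
                pprint(result)
        return eval_loop_results


class StubEvaluationLoop:
    """Stand-in for EvaluationLoop: after this commit it only forwards and returns."""

    def __init__(self, logger_connector):
        self.logger_connector = logger_connector

    def log_epoch_metrics(self, eval_results, test_mode):
        return self.logger_connector.on_evaluation_epoch_end(
            eval_results,
            True,  # using_eval_result, simplified for the sketch
            test_mode,
        )


loop = StubEvaluationLoop(StubLoggerConnector())
print(loop.log_epoch_metrics([{'val_loss': 0.25}], test_mode=True))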

pytorch_lightning/trainer/evaluate_loop.py

@@ -141,9 +141,14 @@ class EvaluationLoop(object):
         eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
         return eval_results
 
-    def log_epoch_metrics(self, eval_results):
+    def log_epoch_metrics(self, eval_results, test_mode):
         using_eval_result = self.is_using_eval_results()
-        self.trainer.logger_connector.on_evaluation_epoch_end(eval_results, using_eval_result)
+        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
+            eval_results,
+            using_eval_result,
+            test_mode
+        )
+        return eval_loop_results
 
     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         model = self.trainer.get_model()

pytorch_lightning/trainer/evaluation_loop.py

@@ -124,33 +124,16 @@ In this second case, the options you pass to trainer will be used when running
 """
 
 from abc import ABC, abstractmethod
 from pprint import pprint
-from typing import Callable, List, Union
+from typing import Callable, List
 
 import torch
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
-from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 from pytorch_lightning.trainer.logger_connector import LoggerConnector
-
-try:
-    import torch_xla.distributed.parallel_loader as xla_pl
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True
 
 
 class TrainerEvaluationLoopMixin(ABC):
@@ -265,15 +248,12 @@ class TrainerEvaluationLoopMixin(ABC):
         eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
 
         # bookkeeping
-        self.evaluation_loop.log_epoch_metrics(eval_results)
+        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
         self.evaluation_loop.predictions.to_disk()
 
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()
 
-        # log the final eval loop metrics
-        eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)
-
         # enable train mode again
         model.train()
         torch.set_grad_enabled(True)
@@ -282,51 +262,3 @@ class TrainerEvaluationLoopMixin(ABC):
         self.evaluation_loop.on_evaluation_end()
 
         return eval_loop_results, eval_results
-
-    def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
-        if self.running_sanity_check:
-            return
-
-        eval_loop_results = []
-        if eval_results is not None and len(eval_results) > 0:
-
-            # in eval, the user may return something at every validation step without final reduction
-            if not isinstance(eval_results, list):
-                eval_results = [eval_results]
-
-            for result_idx, result in enumerate(eval_results):
-                if isinstance(result, EvalResult):
-                    prog_bar_metrics = result.epoch_pbar_metrics
-                    log_metrics = result.epoch_log_metrics
-                    callback_metrics = result.callback_metrics
-
-                    # in testing we don't need the callback metrics
-                    if test_mode:
-                        callback_metrics = {}
-                else:
-                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
-
-                # eval loop returns all metrics
-                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-                # add metrics to prog bar
-                self.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-                # log metrics
-                self.logger_connector.log_metrics(log_metrics, {})
-
-                # track metrics for callbacks
-                self.logger_connector.callback_metrics.update(callback_metrics)
-
-                if len(dataloader_result_metrics) > 0:
-                    eval_loop_results.append(dataloader_result_metrics)
-
-        # log results of test
-        if test_mode and self.is_global_zero and self.verbose_test:
-            print('-' * 80)
-            for result_idx, results in enumerate(eval_loop_results):
-                print(f'DATALOADER:{result_idx} TEST RESULTS')
-                pprint(results)
-                print('-' * 80)
-
-        return eval_loop_results
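
A side note on the move above: because the helper keeps its double-underscore prefix when it lands in LoggerConnector (next file), Python name mangling makes it private to that class, so the trainer mixin can no longer call it directly; on_evaluation_epoch_end becomes the public entry point instead. A tiny standalone illustration of the mangling behaviour (hypothetical class, not library code):

class Connector:
    def public_entry(self, x):
        # inside the class, self.__private is compiled to self._Connector__private
        return self.__private(x)

    def __private(self, x):
        return x * 2


c = Connector()
print(c.public_entry(3))          # 6
print(c._Connector__private(3))   # 6 -- the mangled name is still reachable from outside
# c.__private(3) would raise AttributeError: 'Connector' object has no attribute '__private'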

pytorch_lightning/trainer/logger_connector.py

@@ -15,7 +15,8 @@ import torch
 from pytorch_lightning.core import memory
 from pytorch_lightning.utilities import flatten_dict
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pprint import pprint
 
 
 class LoggerConnector:
@@ -73,7 +74,12 @@ class LoggerConnector:
 
         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
 
-    def on_evaluation_epoch_end(self, eval_results, using_eval_result):
+    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
+        # TODO: merge both functions?
+        self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
+        return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+
+    def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
         if using_eval_result:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
@@ -97,6 +103,54 @@ class LoggerConnector:
                 flat = flatten_dict(eval_results)
                 self.trainer.logger_connector.callback_metrics.update(flat)
 
+    def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+        if self.trainer.running_sanity_check:
+            return
+
+        eval_loop_results = []
+        if eval_results is not None and len(eval_results) > 0:
+
+            # in eval, the user may return something at every validation step without final reduction
+            if not isinstance(eval_results, list):
+                eval_results = [eval_results]
+
+            for result_idx, result in enumerate(eval_results):
+                if isinstance(result, EvalResult):
+                    prog_bar_metrics = result.epoch_pbar_metrics
+                    log_metrics = result.epoch_log_metrics
+                    callback_metrics = result.callback_metrics
+
+                    # in testing we don't need the callback metrics
+                    if test_mode:
+                        callback_metrics = {}
+                else:
+                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_output(result)
+
+                # eval loop returns all metrics
+                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
+
+                # add metrics to prog bar
+                self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
+
+                # log metrics
+                self.trainer.logger_connector.log_metrics(log_metrics, {})
+
+                # track metrics for callbacks
+                self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+
+                if len(dataloader_result_metrics) > 0:
+                    eval_loop_results.append(dataloader_result_metrics)
+
+        # log results of test
+        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+            print('-' * 80)
+            for result_idx, results in enumerate(eval_loop_results):
+                print(f'DATALOADER:{result_idx} TEST RESULTS')
+                pprint(results)
+                print('-' * 80)
+
+        return eval_loop_results
+
     def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
         self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
                                          early_stopping_accumulator, num_optimizers)
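
One behaviour worth keeping in mind when reading the moved helper: the merge {**prog_bar_metrics, **log_metrics, **callback_metrics} builds a single flat dict per dataloader, and on duplicate keys the right-most source wins. A tiny self-contained illustration with made-up metric names (not values from the library):

from pprint import pprint

prog_bar_metrics = {'val_acc': 0.91}
log_metrics = {'val_loss': 0.25, 'val_acc': 0.91}
callback_metrics = {'val_loss': 0.25, 'early_stop_on': 0.25}

# same merge expression as in the diff: later dicts overwrite earlier ones on key clashes
dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}

pprint(dataloader_result_metrics)
# {'early_stop_on': 0.25, 'val_acc': 0.91, 'val_loss': 0.25}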