ref: moved eval loop logging to loggers 1/n (#3408)
This commit is contained in:
parent
8f6b115511
commit
0c2e315950
--- a/pytorch_lightning/trainer/evaluate_loop.py
+++ b/pytorch_lightning/trainer/evaluate_loop.py
@@ -141,9 +141,14 @@ class EvaluationLoop(object):
         eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
         return eval_results

-    def log_epoch_metrics(self, eval_results):
+    def log_epoch_metrics(self, eval_results, test_mode):
         using_eval_result = self.is_using_eval_results()
-        self.trainer.logger_connector.on_evaluation_epoch_end(eval_results, using_eval_result)
+        eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
+            eval_results,
+            using_eval_result,
+            test_mode
+        )
+        return eval_loop_results

     def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         model = self.trainer.get_model()
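The hunk above changes log_epoch_metrics from a fire-and-forget call into one that forwards test_mode and returns whatever the logger connector produces. A minimal standalone sketch of that delegation (not part of this commit; the class bodies and the wiring are hypothetical stand-ins):

    # Toy versions of the two classes, showing only the new return path.
    class LoggerConnector:
        def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
            # pretend we logged, then hand back one result dict per dataloader
            return [{"val_loss": 0.1}]

    class EvaluationLoop:
        def __init__(self, logger_connector):
            self.logger_connector = logger_connector

        def log_epoch_metrics(self, eval_results, test_mode):
            # after this commit the loop returns what the connector produced,
            # instead of logging purely as a side effect
            return self.logger_connector.on_evaluation_epoch_end(
                eval_results, True, test_mode
            )

    loop = EvaluationLoop(LoggerConnector())
    print(loop.log_epoch_metrics([{"val_loss": 0.1}], test_mode=False))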
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -124,33 +124,16 @@ In this second case, the options you pass to trainer will be used when running
 """

 from abc import ABC, abstractmethod
-from pprint import pprint
-from typing import Callable, List, Union
+from typing import Callable, List

 import torch
 from torch.utils.data import DataLoader

 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
-from pytorch_lightning.core.step_result import EvalResult, Result
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
-from pytorch_lightning.trainer.logger_connector import LoggerConnector

-try:
-    import torch_xla.distributed.parallel_loader as xla_pl
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
-try:
-    import horovod.torch as hvd
-except (ModuleNotFoundError, ImportError):
-    HOROVOD_AVAILABLE = False
-else:
-    HOROVOD_AVAILABLE = True


 class TrainerEvaluationLoopMixin(ABC):
@@ -265,15 +248,12 @@ class TrainerEvaluationLoopMixin(ABC):
         eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

         # bookkeeping
-        self.evaluation_loop.log_epoch_metrics(eval_results)
+        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
         self.evaluation_loop.predictions.to_disk()

         # hook
         self.evaluation_loop.on_evaluation_epoch_end()

-        # log the final eval loop metrics
-        eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)
-
         # enable train mode again
         model.train()
         torch.set_grad_enabled(True)
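Aside from the logging move, the tail of this hunk shows the standard PyTorch pattern for leaving evaluation: both the module mode and autograd must be switched back before training resumes. A short self-contained illustration in plain PyTorch (not Lightning):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)

    model.eval()                             # dropout/batchnorm in inference mode
    torch.set_grad_enabled(False)
    out = model(torch.randn(8, 4))           # no autograd graph is built here

    model.train()                            # restore training-mode behaviour
    torch.set_grad_enabled(True)             # restore autograd, as the hunk does
    assert model(torch.randn(8, 4)).requires_grad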
@@ -282,51 +262,3 @@ class TrainerEvaluationLoopMixin(ABC):
         self.evaluation_loop.on_evaluation_end()

         return eval_loop_results, eval_results
-
-    def __log_evaluation_epoch_metrics(self, eval_results, test_mode):
-        if self.running_sanity_check:
-            return
-
-        eval_loop_results = []
-        if eval_results is not None and len(eval_results) > 0:
-
-            # in eval, the user may return something at every validation step without final reduction
-            if not isinstance(eval_results, list):
-                eval_results = [eval_results]
-
-            for result_idx, result in enumerate(eval_results):
-                if isinstance(result, EvalResult):
-                    prog_bar_metrics = result.epoch_pbar_metrics
-                    log_metrics = result.epoch_log_metrics
-                    callback_metrics = result.callback_metrics
-
-                    # in testing we don't need the callback metrics
-                    if test_mode:
-                        callback_metrics = {}
-                else:
-                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result)
-
-                # eval loop returns all metrics
-                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-                # add metrics to prog bar
-                self.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-                # log metrics
-                self.logger_connector.log_metrics(log_metrics, {})
-
-                # track metrics for callbacks
-                self.logger_connector.callback_metrics.update(callback_metrics)
-
-                if len(dataloader_result_metrics) > 0:
-                    eval_loop_results.append(dataloader_result_metrics)
-
-        # log results of test
-        if test_mode and self.is_global_zero and self.verbose_test:
-            print('-' * 80)
-            for result_idx, results in enumerate(eval_loop_results):
-                print(f'DATALOADER:{result_idx} TEST RESULTS')
-                pprint(results)
-                print('-' * 80)
-
-        return eval_loop_results
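For reference, the dataloader_result_metrics line in the removed method (and in its copy added to the logger connector below) relies on plain dict unpacking, where later sources win on duplicate keys. A tiny demonstration with made-up values:

    prog_bar_metrics = {"val_loss": 0.25}
    log_metrics = {"val_loss": 0.25, "val_acc": 0.91}
    callback_metrics = {"val_acc": 0.91, "epoch": 3}

    # later ** sources override earlier ones on key collisions
    dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
    assert dataloader_result_metrics == {"val_loss": 0.25, "val_acc": 0.91, "epoch": 3}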
--- a/pytorch_lightning/trainer/logger_connector.py
+++ b/pytorch_lightning/trainer/logger_connector.py
@@ -15,7 +15,8 @@ import torch
 from pytorch_lightning.core import memory
 from pytorch_lightning.utilities import flatten_dict
 from pytorch_lightning.utilities.model_utils import is_overridden
-from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.core.step_result import EvalResult, Result
+from pprint import pprint


 class LoggerConnector:
@@ -73,7 +74,12 @@ class LoggerConnector:

         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

-    def on_evaluation_epoch_end(self, eval_results, using_eval_result):
+    def on_evaluation_epoch_end(self, eval_results, using_eval_result, test_mode):
+        # TODO: merge both functions?
+        self._log_on_evaluation_epoch_end_metrics(eval_results, using_eval_result)
+        return self.__log_evaluation_epoch_metrics_2(eval_results, test_mode)
+
+    def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
         if using_eval_result:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
@@ -97,6 +103,54 @@ class LoggerConnector:
             flat = flatten_dict(eval_results)
             self.trainer.logger_connector.callback_metrics.update(flat)

+    def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
+        if self.trainer.running_sanity_check:
+            return
+
+        eval_loop_results = []
+        if eval_results is not None and len(eval_results) > 0:
+
+            # in eval, the user may return something at every validation step without final reduction
+            if not isinstance(eval_results, list):
+                eval_results = [eval_results]
+
+            for result_idx, result in enumerate(eval_results):
+                if isinstance(result, EvalResult):
+                    prog_bar_metrics = result.epoch_pbar_metrics
+                    log_metrics = result.epoch_log_metrics
+                    callback_metrics = result.callback_metrics
+
+                    # in testing we don't need the callback metrics
+                    if test_mode:
+                        callback_metrics = {}
+                else:
+                    _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_output(result)
+
+                # eval loop returns all metrics
+                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
+
+                # add metrics to prog bar
+                self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
+
+                # log metrics
+                self.trainer.logger_connector.log_metrics(log_metrics, {})
+
+                # track metrics for callbacks
+                self.trainer.logger_connector.callback_metrics.update(callback_metrics)
+
+                if len(dataloader_result_metrics) > 0:
+                    eval_loop_results.append(dataloader_result_metrics)
+
+        # log results of test
+        if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
+            print('-' * 80)
+            for result_idx, results in enumerate(eval_loop_results):
+                print(f'DATALOADER:{result_idx} TEST RESULTS')
+                pprint(results)
+                print('-' * 80)
+
+        return eval_loop_results
+
     def on_train_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator, num_optimizers):
         self.log_train_epoch_end_metrics(epoch_output, checkpoint_accumulator,
                                          early_stopping_accumulator, num_optimizers)
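One subtlety of the moved method: its double-underscore prefix is name-mangled per defining class, so after the move it resolves inside LoggerConnector as _LoggerConnector__log_evaluation_epoch_metrics_2 and stays effectively private there. A toy illustration (these classes are stand-ins, not the Lightning ones):

    class LoggerConnector:
        def on_evaluation_epoch_end(self):
            # resolves to self._LoggerConnector__log_evaluation_epoch_metrics_2
            return self.__log_evaluation_epoch_metrics_2()

        def __log_evaluation_epoch_metrics_2(self):
            return ["metrics"]

    connector = LoggerConnector()
    assert connector.on_evaluation_epoch_end() == ["metrics"]
    # from outside, only the mangled attribute name exists:
    assert hasattr(connector, "_LoggerConnector__log_evaluation_epoch_metrics_2")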