ref: refactored inner eval loop (#3141)

* refactored dataloader process hook

William Falcon 2020-08-24 22:50:59 -04:00 committed by GitHub
parent f064d74be8
commit ccc923cbb0
2 changed files with 68 additions and 74 deletions

pytorch_lightning/trainer/evaluate_loop.py

@@ -1,6 +1,7 @@
 import torch
 from pytorch_lightning.trainer.supporters import PredictionCollection
-from pytorch_lightning.core.step_result import EvalResult
+from pytorch_lightning.core.step_result import Result, EvalResult
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class EvaluationLoop(object):
@@ -43,11 +44,38 @@ class EvaluationLoop(object):
         else:
             self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

-    def evaluation_step(self, *args, **kwargs):
+    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
+        # make dataloader_idx arg in validation_step optional
+        args = [batch, batch_idx]
+
+        multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
+        multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
+
+        if multiple_test_loaders or multiple_val_loaders:
+            args.append(dataloader_idx)
+
+        return args
+
+    def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
+        # configure args
+        args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+
         # run actual test step
         if self.testing:
-            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.test_step(args)
         else:
-            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+            output = self.trainer.accelerator_backend.validation_step(args)
+
+        # track batch size for weighted average
+        is_result_obj = isinstance(output, Result)
+        if is_result_obj:
+            output.track_batch_size(len(batch))
+
+        # allow only EvalResult when using structured results (from val_step)
+        if is_result_obj and not isinstance(output, EvalResult):
+            m = 'only EvalResults or dicts are allowed from validation_step'
+            raise MisconfigurationException(m)
+
         return output

     def evaluation_step_end(self, *args, **kwargs):
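Aside (not part of the diff): a minimal sketch of the argument-building rule introduced by build_args above. dataloader_idx is forwarded to validation_step/test_step only when more than one val/test dataloader is configured; the standalone function and the loader-count parameters below are illustrative stand-ins for the trainer state.

# Illustrative only; mirrors EvaluationLoop.build_args without a Trainer instance.
def build_args_sketch(test_mode, batch, batch_idx, dataloader_idx,
                      num_val_loaders=1, num_test_loaders=1):
    args = [batch, batch_idx]
    multiple_val_loaders = (not test_mode and num_val_loaders > 1)
    multiple_test_loaders = (test_mode and num_test_loaders > 1)
    if multiple_test_loaders or multiple_val_loaders:
        args.append(dataloader_idx)
    return args

# one val dataloader -> validation_step(batch, batch_idx)
assert build_args_sketch(False, 'batch', 0, 0, num_val_loaders=1) == ['batch', 0]
# two val dataloaders -> validation_step(batch, batch_idx, dataloader_idx)
assert build_args_sketch(False, 'batch', 0, 1, num_val_loaders=2) == ['batch', 0, 1]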
@@ -69,8 +97,37 @@ class EvaluationLoop(object):
         else:
             self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

+    def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
+        # Add step predictions to prediction collection to write later
+        if output is not None:
+            do_write_predictions = isinstance(output, Result) and self.testing
+            if do_write_predictions:
+                self.predictions.add(output.pop('predictions', None))
+
+        # track debug metrics
+        self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
+
     def on_evaluation_epoch_end(self, *args, **kwargs):
         if self.testing:
             self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
         else:
             self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+
+    def log_metrics(self, output, batch_idx):
+        if self.trainer.running_sanity_check:
+            return
+
+        if isinstance(output, EvalResult):
+            step_log_metrics = output.batch_log_metrics
+            step_pbar_metrics = output.batch_pbar_metrics
+
+            if len(step_log_metrics) > 0:
+                # make the metrics appear as a different line in the same graph
+                metrics_by_epoch = {}
+                for k, v in step_log_metrics.items():
+                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
+                self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+
+            if len(step_pbar_metrics) > 0:
+                self.trainer.add_progress_bar_metrics(step_pbar_metrics)
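Aside (not part of the diff): a small example of the per-epoch renaming that log_metrics above applies to step-level log metrics; the metric names and epoch value are made up.

# Hypothetical values; shows only the f'{k}/epoch_{current_epoch}' renaming.
step_log_metrics = {'val_loss': 0.42, 'val_acc': 0.91}
current_epoch = 3
metrics_by_epoch = {f'{k}/epoch_{current_epoch}': v for k, v in step_log_metrics.items()}
assert metrics_by_epoch == {'val_loss/epoch_3': 0.42, 'val_acc/epoch_3': 0.91}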

pytorch_lightning/trainer/evaluation_loop.py

@@ -132,8 +132,7 @@ from torch.utils.data import DataLoader
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
-from pytorch_lightning.core.step_result import Result, EvalResult
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 try:
@@ -273,55 +272,19 @@ class TrainerEvaluationLoopMixin(ABC):
                 if batch_idx >= dl_max_batches:
                     break

-                # -----------------
-                # eval_batch_start
-                # -----------------
+                # val loop hooks
                 self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
-
-                # -----------------
-                # RUN EVALUATION STEP
-                # -----------------
-                args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-                output = self.evaluation_loop.evaluation_step(args)
-
-                # track batch size for weighted average
-                is_result_obj = isinstance(output, Result)
-                if is_result_obj:
-                    output.track_batch_size(len(batch))
-
-                # allow only EvalResult when using structured results (from val_step)
-                if is_result_obj and not isinstance(output, EvalResult):
-                    m = 'only EvalResults or dicts are allowed from validation_step'
-                    raise MisconfigurationException(m)
-
-                # ------------------
-                # EVAL STEP END
-                # ------------------
+                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                 output = self.evaluation_loop.evaluation_step_end(output)
-
-                # ------------------
-                # Hook: on_eval_batch_end
-                # ------------------
                 self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

-                # ----------------------
-                # Post processing
-                # ----------------------
-                # track outputs for collation
+                # clean up
+                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
+                self.evaluation_loop.log_metrics(output, batch_idx)
+
                 if output is not None:
-                    # Add step predictions to prediction collection to write later
-                    do_write_predictions = is_result_obj and test_mode
-                    if do_write_predictions:
-                        self.evaluation_loop.predictions.add(output.pop('predictions', None))
-
                     dl_outputs.append(output)

-                self.__eval_add_step_metrics(output, batch_idx)
-
-                # track debug metrics
-                self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
-
             self.evaluation_loop.outputs.append(dl_outputs)

             # ---------------------
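Aside (not part of the diff): after this hunk, the per-batch body of the inner loop reduces to roughly the sketch below; evaluation_loop stands for the EvaluationLoop instance from the first file, and the surrounding dataloader/batch loops are omitted.

# Rough sketch of the refactored inner-loop body for a single batch.
def run_one_eval_batch(evaluation_loop, test_mode, batch, batch_idx, dataloader_idx, dl_outputs):
    evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
    output = evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
    output = evaluation_loop.evaluation_step_end(output)
    evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
    evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
    evaluation_loop.log_metrics(output, batch_idx)
    if output is not None:
        dl_outputs.append(output)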
@@ -454,23 +417,6 @@ class TrainerEvaluationLoopMixin(ABC):
             eval_results = eval_results[0]

         return eval_results

-    def __eval_add_step_metrics(self, output, batch_idx):
-        # track step level metrics
-        if isinstance(output, EvalResult) and not self.running_sanity_check:
-            step_log_metrics = output.batch_log_metrics
-            step_pbar_metrics = output.batch_pbar_metrics
-
-            if len(step_log_metrics) > 0:
-                # make the metrics appear as a different line in the same graph
-                metrics_by_epoch = {}
-                for k, v in step_log_metrics.items():
-                    metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
-                self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
-
-            if len(step_pbar_metrics) > 0:
-                self.add_progress_bar_metrics(step_pbar_metrics)
-
     def __auto_reduce_result_objs(self, outputs):
         # outputs has a list of results per dataloader
         eval_results = []
@@ -588,12 +534,3 @@ class TrainerEvaluationLoopMixin(ABC):
             print('-' * 80)

         return eval_loop_results
-
-    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
-        # make dataloader_idx arg in validation_step optional
-        args = [batch, batch_idx]
-
-        if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
-            args.append(dataloader_idx)
-
-        return args