ref: refactored inner eval loop (#3141)
* refactored dataloader process hook * refactored dataloader process hook * refactored dataloader process hook
This commit is contained in:
parent
f064d74be8
commit
ccc923cbb0
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.trainer.supporters import PredictionCollection
|
from pytorch_lightning.trainer.supporters import PredictionCollection
|
||||||
from pytorch_lightning.core.step_result import EvalResult
|
from pytorch_lightning.core.step_result import Result, EvalResult
|
||||||
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
|
||||||
|
|
||||||
class EvaluationLoop(object):
|
class EvaluationLoop(object):
|
||||||
|
@ -43,11 +44,38 @@ class EvaluationLoop(object):
|
||||||
else:
|
else:
|
||||||
self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
|
self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
|
||||||
|
|
||||||
def evaluation_step(self, *args, **kwargs):
|
def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||||
|
# make dataloader_idx arg in validation_step optional
|
||||||
|
args = [batch, batch_idx]
|
||||||
|
|
||||||
|
multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
|
||||||
|
multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
|
||||||
|
|
||||||
|
if multiple_test_loaders or multiple_val_loaders:
|
||||||
|
args.append(dataloader_idx)
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||||
|
# configure args
|
||||||
|
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
|
||||||
|
|
||||||
|
# run actual test step
|
||||||
if self.testing:
|
if self.testing:
|
||||||
output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
|
output = self.trainer.accelerator_backend.test_step(args)
|
||||||
else:
|
else:
|
||||||
output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
|
output = self.trainer.accelerator_backend.validation_step(args)
|
||||||
|
|
||||||
|
# track batch size for weighted average
|
||||||
|
is_result_obj = isinstance(output, Result)
|
||||||
|
if is_result_obj:
|
||||||
|
output.track_batch_size(len(batch))
|
||||||
|
|
||||||
|
# allow only EvalResult when using structured results (from val_step)
|
||||||
|
if is_result_obj and not isinstance(output, EvalResult):
|
||||||
|
m = 'only EvalResults or dicts are allowed from validation_step'
|
||||||
|
raise MisconfigurationException(m)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def evaluation_step_end(self, *args, **kwargs):
|
def evaluation_step_end(self, *args, **kwargs):
|
||||||
|
@ -69,8 +97,37 @@ class EvaluationLoop(object):
|
||||||
else:
|
else:
|
||||||
self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
|
self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
|
||||||
|
|
||||||
|
def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
|
||||||
|
# Add step predictions to prediction collection to write later
|
||||||
|
if output is not None:
|
||||||
|
do_write_predictions = isinstance(output, Result) and self.testing
|
||||||
|
if do_write_predictions:
|
||||||
|
self.predictions.add(output.pop('predictions', None))
|
||||||
|
|
||||||
|
# track debug metrics
|
||||||
|
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
|
||||||
|
|
||||||
def on_evaluation_epoch_end(self, *args, **kwargs):
|
def on_evaluation_epoch_end(self, *args, **kwargs):
|
||||||
if self.testing:
|
if self.testing:
|
||||||
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
||||||
|
|
||||||
|
def log_metrics(self, output, batch_idx):
|
||||||
|
if self.trainer.running_sanity_check:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(output, EvalResult):
|
||||||
|
step_log_metrics = output.batch_log_metrics
|
||||||
|
step_pbar_metrics = output.batch_pbar_metrics
|
||||||
|
|
||||||
|
if len(step_log_metrics) > 0:
|
||||||
|
# make the metrics appear as a different line in the same graph
|
||||||
|
metrics_by_epoch = {}
|
||||||
|
for k, v in step_log_metrics.items():
|
||||||
|
metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
|
||||||
|
|
||||||
|
self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)
|
||||||
|
|
||||||
|
if len(step_pbar_metrics) > 0:
|
||||||
|
self.trainer.add_progress_bar_metrics(step_pbar_metrics)
|
||||||
|
|
|
@ -132,8 +132,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
from pytorch_lightning.core.lightning import LightningModule
|
||||||
from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
|
from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
|
||||||
from pytorch_lightning.core.step_result import Result, EvalResult
|
from pytorch_lightning.core.step_result import EvalResult, Result
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
||||||
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
|
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -273,55 +272,19 @@ class TrainerEvaluationLoopMixin(ABC):
|
||||||
if batch_idx >= dl_max_batches:
|
if batch_idx >= dl_max_batches:
|
||||||
break
|
break
|
||||||
|
|
||||||
# -----------------
|
# val loop hooks
|
||||||
# eval_batch_start
|
|
||||||
# -----------------
|
|
||||||
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
|
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
|
||||||
|
output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
|
||||||
# -----------------
|
|
||||||
# RUN EVALUATION STEP
|
|
||||||
# -----------------
|
|
||||||
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
|
|
||||||
output = self.evaluation_loop.evaluation_step(args)
|
|
||||||
|
|
||||||
# track batch size for weighted average
|
|
||||||
is_result_obj = isinstance(output, Result)
|
|
||||||
if is_result_obj:
|
|
||||||
output.track_batch_size(len(batch))
|
|
||||||
|
|
||||||
# allow only EvalResult when using structured results (from val_step)
|
|
||||||
if is_result_obj and not isinstance(output, EvalResult):
|
|
||||||
m = 'only EvalResults or dicts are allowed from validation_step'
|
|
||||||
raise MisconfigurationException(m)
|
|
||||||
|
|
||||||
# ------------------
|
|
||||||
# EVAL STEP END
|
|
||||||
# ------------------
|
|
||||||
output = self.evaluation_loop.evaluation_step_end(output)
|
output = self.evaluation_loop.evaluation_step_end(output)
|
||||||
|
|
||||||
# ------------------
|
|
||||||
# Hook: on_eval_batch_end
|
|
||||||
# ------------------
|
|
||||||
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
|
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
|
||||||
|
|
||||||
# ----------------------
|
# clean up
|
||||||
# Post processing
|
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
|
||||||
# ----------------------
|
self.evaluation_loop.log_metrics(output, batch_idx)
|
||||||
# track outputs for collation
|
|
||||||
if output is not None:
|
if output is not None:
|
||||||
|
|
||||||
# Add step predictions to prediction collection to write later
|
|
||||||
do_write_predictions = is_result_obj and test_mode
|
|
||||||
if do_write_predictions:
|
|
||||||
self.evaluation_loop.predictions.add(output.pop('predictions', None))
|
|
||||||
|
|
||||||
dl_outputs.append(output)
|
dl_outputs.append(output)
|
||||||
|
|
||||||
self.__eval_add_step_metrics(output, batch_idx)
|
|
||||||
|
|
||||||
# track debug metrics
|
|
||||||
self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
|
|
||||||
|
|
||||||
self.evaluation_loop.outputs.append(dl_outputs)
|
self.evaluation_loop.outputs.append(dl_outputs)
|
||||||
|
|
||||||
# ---------------------
|
# ---------------------
|
||||||
|
@ -454,23 +417,6 @@ class TrainerEvaluationLoopMixin(ABC):
|
||||||
eval_results = eval_results[0]
|
eval_results = eval_results[0]
|
||||||
return eval_results
|
return eval_results
|
||||||
|
|
||||||
def __eval_add_step_metrics(self, output, batch_idx):
|
|
||||||
# track step level metrics
|
|
||||||
if isinstance(output, EvalResult) and not self.running_sanity_check:
|
|
||||||
step_log_metrics = output.batch_log_metrics
|
|
||||||
step_pbar_metrics = output.batch_pbar_metrics
|
|
||||||
|
|
||||||
if len(step_log_metrics) > 0:
|
|
||||||
# make the metrics appear as a different line in the same graph
|
|
||||||
metrics_by_epoch = {}
|
|
||||||
for k, v in step_log_metrics.items():
|
|
||||||
metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v
|
|
||||||
|
|
||||||
self.log_metrics(metrics_by_epoch, {}, step=batch_idx)
|
|
||||||
|
|
||||||
if len(step_pbar_metrics) > 0:
|
|
||||||
self.add_progress_bar_metrics(step_pbar_metrics)
|
|
||||||
|
|
||||||
def __auto_reduce_result_objs(self, outputs):
|
def __auto_reduce_result_objs(self, outputs):
|
||||||
# outputs has a list of results per dataloader
|
# outputs has a list of results per dataloader
|
||||||
eval_results = []
|
eval_results = []
|
||||||
|
@ -588,12 +534,3 @@ class TrainerEvaluationLoopMixin(ABC):
|
||||||
print('-' * 80)
|
print('-' * 80)
|
||||||
|
|
||||||
return eval_loop_results
|
return eval_loop_results
|
||||||
|
|
||||||
def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
|
|
||||||
# make dataloader_idx arg in validation_step optional
|
|
||||||
args = [batch, batch_idx]
|
|
||||||
|
|
||||||
if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
|
|
||||||
args.append(dataloader_idx)
|
|
||||||
|
|
||||||
return args
|
|
||||||
|
|
Loading…
Reference in New Issue