diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py index d1d8977843..855c60dfb7 100644 --- a/pytorch_lightning/trainer/evaluate_loop.py +++ b/pytorch_lightning/trainer/evaluate_loop.py @@ -1,6 +1,7 @@ import torch from pytorch_lightning.trainer.supporters import PredictionCollection -from pytorch_lightning.core.step_result import EvalResult +from pytorch_lightning.core.step_result import Result, EvalResult +from pytorch_lightning.utilities.exceptions import MisconfigurationException class EvaluationLoop(object): @@ -43,11 +44,38 @@ class EvaluationLoop(object): else: self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - def evaluation_step(self, *args, **kwargs): + def build_args(self, test_mode, batch, batch_idx, dataloader_idx): + # make dataloader_idx arg in validation_step optional + args = [batch, batch_idx] + + multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1) + multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1) + + if multiple_test_loaders or multiple_val_loaders: + args.append(dataloader_idx) + + return args + + def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): + # configure args + args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) + + # run actual test step if self.testing: - output = self.trainer.accelerator_backend.test_step(*args, **kwargs) + output = self.trainer.accelerator_backend.test_step(args) else: - output = self.trainer.accelerator_backend.validation_step(*args, **kwargs) + output = self.trainer.accelerator_backend.validation_step(args) + + # track batch size for weighted average + is_result_obj = isinstance(output, Result) + if is_result_obj: + output.track_batch_size(len(batch)) + + # allow only EvalResult when using structured results (from val_step) + if is_result_obj and not isinstance(output, EvalResult): + m = 'only EvalResults or dicts are allowed from validation_step' + raise MisconfigurationException(m) + return output def evaluation_step_end(self, *args, **kwargs): @@ -69,8 +97,37 @@ class EvaluationLoop(object): else: self.trainer.call_hook('on_validation_batch_end', *args, **kwargs) + def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx): + # Add step predictions to prediction collection to write later + if output is not None: + do_write_predictions = isinstance(output, Result) and self.testing + if do_write_predictions: + self.predictions.add(output.pop('predictions', None)) + + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) + def on_evaluation_epoch_end(self, *args, **kwargs): if self.testing: self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) else: self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + + def log_metrics(self, output, batch_idx): + if self.trainer.running_sanity_check: + return + + if isinstance(output, EvalResult): + step_log_metrics = output.batch_log_metrics + step_pbar_metrics = output.batch_pbar_metrics + + if len(step_log_metrics) > 0: + # make the metrics appear as a different line in the same graph + metrics_by_epoch = {} + for k, v in step_log_metrics.items(): + metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v + + self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx) + + if len(step_pbar_metrics) > 0: + self.trainer.add_progress_bar_metrics(step_pbar_metrics) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2911bff94f..53d32cf6f9 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -132,8 +132,7 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType -from pytorch_lightning.core.step_result import Result, EvalResult -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop try: @@ -273,55 +272,19 @@ class TrainerEvaluationLoopMixin(ABC): if batch_idx >= dl_max_batches: break - # ----------------- - # eval_batch_start - # ----------------- + # val loop hooks self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - - # ----------------- - # RUN EVALUATION STEP - # ----------------- - args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) - output = self.evaluation_loop.evaluation_step(args) - - # track batch size for weighted average - is_result_obj = isinstance(output, Result) - if is_result_obj: - output.track_batch_size(len(batch)) - - # allow only EvalResult when using structured results (from val_step) - if is_result_obj and not isinstance(output, EvalResult): - m = 'only EvalResults or dicts are allowed from validation_step' - raise MisconfigurationException(m) - - # ------------------ - # EVAL STEP END - # ------------------ + output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) output = self.evaluation_loop.evaluation_step_end(output) - - # ------------------ - # Hook: on_eval_batch_end - # ------------------ self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx) - # ---------------------- - # Post processing - # ---------------------- - # track outputs for collation + # clean up + self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx) + self.evaluation_loop.log_metrics(output, batch_idx) + if output is not None: - - # Add step predictions to prediction collection to write later - do_write_predictions = is_result_obj and test_mode - if do_write_predictions: - self.evaluation_loop.predictions.add(output.pop('predictions', None)) - dl_outputs.append(output) - self.__eval_add_step_metrics(output, batch_idx) - - # track debug metrics - self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output) - self.evaluation_loop.outputs.append(dl_outputs) # --------------------- @@ -454,23 +417,6 @@ class TrainerEvaluationLoopMixin(ABC): eval_results = eval_results[0] return eval_results - def __eval_add_step_metrics(self, output, batch_idx): - # track step level metrics - if isinstance(output, EvalResult) and not self.running_sanity_check: - step_log_metrics = output.batch_log_metrics - step_pbar_metrics = output.batch_pbar_metrics - - if len(step_log_metrics) > 0: - # make the metrics appear as a different line in the same graph - metrics_by_epoch = {} - for k, v in step_log_metrics.items(): - metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v - - self.log_metrics(metrics_by_epoch, {}, step=batch_idx) - - if len(step_pbar_metrics) > 0: - self.add_progress_bar_metrics(step_pbar_metrics) - def __auto_reduce_result_objs(self, outputs): # outputs has a list of results per dataloader eval_results = [] @@ -588,12 +534,3 @@ class TrainerEvaluationLoopMixin(ABC): print('-' * 80) return eval_loop_results - - def build_args(self, test_mode, batch, batch_idx, dataloader_idx): - # make dataloader_idx arg in validation_step optional - args = [batch, batch_idx] - - if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1): - args.append(dataloader_idx) - - return args