From a46130cdc19447ea86466bf4fcd0c7226e968a05 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 12 Aug 2020 08:02:00 -0400
Subject: [PATCH] add weighted average to results obj (#2930)

* track batch size in result obj
---
 pytorch_lightning/core/step_result.py        | 29 ++++++++++++++++++--
 pytorch_lightning/trainer/evaluation_loop.py |  8 +++++-
 pytorch_lightning/trainer/training_loop.py   | 13 +++++----
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index ea62fdab2e..eea6e07822 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -36,7 +36,8 @@ class Result(Dict):
 
         self['meta'] = {
             '_internal': {
-                '_reduce_on_epoch': False
+                '_reduce_on_epoch': False,
+                'batch_sizes': []
             }
         }
 
@@ -166,6 +167,14 @@ class Result(Dict):
         _internal = self['meta']['_internal']
         _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
 
+    def track_batch_size(self, batch_size):
+        meta = self['meta']
+        meta['_internal']['batch_sizes'].append(batch_size)
+
+    def get_batch_sizes(self):
+        meta = self['meta']
+        return torch.tensor(meta['_internal']['batch_sizes'])
+
     def get_callback_metrics(self) -> dict:
         result = {
             'early_stop_on': self.early_stop_on,
@@ -301,18 +310,27 @@ class Result(Dict):
 
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
+        # get the batch sizes for all outputs
+        batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+
         meta = outputs[0]['meta']
         result = cls()
         result = recursive_gather(outputs, result)
         recursive_stack(result)
+
         for k, option in meta.items():
             if k == '_internal':
                 continue
 
             if option['on_epoch']:
                 fx = option['reduce_fx']
-                result[k] = fx(result[k])
+                if fx == torch.mean:
+                    reduced_val = weighted_mean(result[k], batch_sizes)
+                else:
+                    reduced_val = fx(result[k])
+
+                result[k] = reduced_val
 
         result['meta'] = meta
         return result
 
@@ -713,3 +731,10 @@ class EvalResult(Result):
         }
 
         return result
+
+
+def weighted_mean(result, weights):
+    weights = weights.to(result.device)
+    numerator = torch.dot(result.float(), weights.t().float())
+    result = numerator / weights.sum().float()
+    return result
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 433ea97087..7e90eb6dc6 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -331,8 +331,14 @@ class TrainerEvaluationLoopMixin(ABC):
             else:
                 output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
 
+                is_result_obj = isinstance(output, Result)
+
+                # track batch size for weighted average
+                if is_result_obj:
+                    output.track_batch_size(len(batch))
+
                 # allow only EvalResult when using structured results (from val_step)
-                if isinstance(output, Result) and not isinstance(output, EvalResult):
+                if is_result_obj and not isinstance(output, EvalResult):
                     m = 'only EvalResults or dicts are allowed from validation_step'
                     raise MisconfigurationException(m)
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ec5bd0938d..8bd0dea623 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -848,16 +848,13 @@ class TrainerTrainLoopMixin(ABC):
                 # add metrics to loggers
                 if using_results_obj:
                     metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
-                else:
-                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
-                batch_log_metrics.append(metrics_to_log)
-
-                # add metrics to progress bar
-                if using_results_obj:
                     step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
                 else:
+                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
                     step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
 
+                # track metrics
+                batch_log_metrics.append(metrics_to_log)
                 if len(step_pbar_metrics) > 0:
                     self.add_progress_bar_metrics(step_pbar_metrics)
 
@@ -1018,6 +1015,10 @@ class TrainerTrainLoopMixin(ABC):
             training_step_output_for_epoch_end = training_step_output
             is_result_obj = isinstance(training_step_output, Result)
 
+            # track batch size for weighted average
+            if is_result_obj:
+                training_step_output.track_batch_size(len(split_batch))
+
             # don't allow EvalResult in the training_step
             if isinstance(training_step_output, EvalResult):
                 raise MisconfigurationException('training_step cannot return EvalResult, '
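
---

Note (not part of the patch): a minimal standalone sketch of what the weighted
reduction changes. `weighted_mean` is copied from the new helper in
step_result.py; the metric values and batch sizes below are made-up
illustration data.

    import torch

    def weighted_mean(result, weights):
        # identical to the helper this patch adds to step_result.py
        weights = weights.to(result.device)
        numerator = torch.dot(result.float(), weights.t().float())
        return numerator / weights.sum().float()

    # per-batch mean accuracies from three validation batches; the last
    # batch is smaller because the dataset size is not divisible by 32
    per_batch_acc = torch.tensor([0.50, 0.50, 1.00])
    batch_sizes = torch.tensor([32, 32, 4])

    plain = per_batch_acc.mean()                          # 0.6667: the 4-sample batch counts as much as a 32-sample batch
    weighted = weighted_mean(per_batch_acc, batch_sizes)  # (16 + 16 + 4) / 68 = 0.5294: the true per-sample accuracy

    print(plain.item(), weighted.item())

Since `reduce_on_epoch_end` now routes any metric whose `reduce_fx` is
`torch.mean` (the logging default) through `weighted_mean`, epoch-end metrics
reflect the per-sample average rather than an average of per-batch averages.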