add weighted average to results obj (#2930)

* track batch size in result obj
This commit is contained in:
William Falcon 2020-08-12 08:02:00 -04:00 committed by GitHub
parent 118bd14d16
commit a46130cdc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 9 deletions

View File

@ -36,7 +36,8 @@ class Result(Dict):
self['meta'] = {
'_internal': {
'_reduce_on_epoch': False
'_reduce_on_epoch': False,
'batch_sizes': []
}
}
@ -166,6 +167,14 @@ class Result(Dict):
_internal = self['meta']['_internal']
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
def track_batch_size(self, batch_size):
    """Record the size of one processed batch.

    The sizes accumulate in ``meta['_internal']['batch_sizes']`` so that
    epoch-end reduction can compute a batch-size weighted average.
    """
    self['meta']['_internal']['batch_sizes'].append(batch_size)
def get_batch_sizes(self):
    """Return every tracked batch size as a 1-D ``torch.Tensor``."""
    sizes = self['meta']['_internal']['batch_sizes']
    return torch.tensor(sizes)
def get_callback_metrics(self) -> dict:
result = {
'early_stop_on': self.early_stop_on,
@ -301,18 +310,27 @@ class Result(Dict):
@classmethod
def reduce_on_epoch_end(cls, outputs):
    """Aggregate per-step Result objects into one epoch-level Result.

    Metrics flagged ``on_epoch`` are reduced with their configured
    ``reduce_fx``; a plain ``torch.mean`` is replaced by a batch-size
    weighted mean so unevenly sized batches do not bias the epoch value.

    Args:
        outputs: list of Result objects produced during the epoch.

    Returns:
        A new Result holding the reduced metrics plus the original meta.
    """
    # flatten every step's tracked batch sizes into a single 1-D tensor
    batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)

    meta = outputs[0]['meta']
    result = cls()
    result = recursive_gather(outputs, result)
    recursive_stack(result)

    for k, option in meta.items():
        # bookkeeping entry, not a user metric
        if k == '_internal':
            continue

        if option['on_epoch']:
            fx = option['reduce_fx']
            # NOTE: fixed diff artifact — the reduction must run exactly once;
            # applying fx first and then weighted_mean would reduce twice.
            if fx == torch.mean:
                # weight by per-batch sample counts instead of a plain mean
                reduced_val = weighted_mean(result[k], batch_sizes)
            else:
                reduced_val = fx(result[k])
            result[k] = reduced_val

    result['meta'] = meta
    return result
@ -713,3 +731,10 @@ class EvalResult(Result):
}
return result
def weighted_mean(result, weights):
    """Batch-size weighted mean of per-batch metric values.

    Args:
        result: 1-D tensor of metric values, one entry per batch.
        weights: 1-D tensor of batch sizes, same length as ``result``.

    Returns:
        Scalar tensor ``sum(result * weights) / sum(weights)``.
    """
    # move weights onto the metric's device before the dot product
    weights = weights.to(result.device)
    # removed the redundant `.t()`: transpose is a no-op on the 1-D
    # tensors torch.dot accepts, and would not help 2-D input either
    numerator = torch.dot(result.float(), weights.float())
    return numerator / weights.sum().float()

View File

@ -331,8 +331,14 @@ class TrainerEvaluationLoopMixin(ABC):
else:
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
is_result_obj = isinstance(output, Result)
# track batch size for weighted average
if is_result_obj:
output.track_batch_size(len(batch))
# allow only EvalResult when using structured results (from val_step)
if isinstance(output, Result) and not isinstance(output, EvalResult):
if is_result_obj and not isinstance(output, EvalResult):
m = 'only EvalResults or dicts are allowed from validation_step'
raise MisconfigurationException(m)

View File

@ -848,16 +848,13 @@ class TrainerTrainLoopMixin(ABC):
# add metrics to loggers
if using_results_obj:
metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
batch_log_metrics.append(metrics_to_log)
# add metrics to progress bar
if using_results_obj:
step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
# track metrics
batch_log_metrics.append(metrics_to_log)
if len(step_pbar_metrics) > 0:
self.add_progress_bar_metrics(step_pbar_metrics)
@ -1018,6 +1015,10 @@ class TrainerTrainLoopMixin(ABC):
training_step_output_for_epoch_end = training_step_output
is_result_obj = isinstance(training_step_output, Result)
# track batch size for weighted average
if is_result_obj:
training_step_output.track_batch_size(len(split_batch))
# don't allow EvalResult in the training_step
if isinstance(training_step_output, EvalResult):
raise MisconfigurationException('training_step cannot return EvalResult, '