add weighted average to results obj (#2930)
* track batch size in result obj
parent 118bd14d16
commit a46130cdc1

@@ -36,7 +36,8 @@ class Result(Dict):
         self['meta'] = {
             '_internal': {
-                '_reduce_on_epoch': False
+                '_reduce_on_epoch': False,
+                'batch_sizes': []
             }
         }

@@ -166,6 +167,14 @@ class Result(Dict):
         _internal = self['meta']['_internal']
         _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
 
+    def track_batch_size(self, batch_size):
+        meta = self['meta']
+        meta['_internal']['batch_sizes'].append(batch_size)
+
+    def get_batch_sizes(self):
+        meta = self['meta']
+        return torch.tensor(meta['_internal']['batch_sizes'])
+
     def get_callback_metrics(self) -> dict:
         result = {
             'early_stop_on': self.early_stop_on,

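A minimal sketch of the two new accessors in isolation (the import path and the numbers are assumptions for illustration, not part of the diff):

    import torch
    from pytorch_lightning.core.step_result import Result  # assumed module path

    result = Result()
    # the training/evaluation loops call this once per processed batch
    result.track_batch_size(32)
    result.track_batch_size(32)
    result.track_batch_size(8)   # e.g. a smaller final batch

    sizes = result.get_batch_sizes()   # tensor([32, 32, 8])
    assert int(sizes.sum()) == 72
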
@@ -301,18 +310,27 @@ class Result(Dict):
 
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
+        # get the batch sizes for all outputs
+        batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+
         meta = outputs[0]['meta']
         result = cls()
         result = recursive_gather(outputs, result)
         recursive_stack(result)
 
         for k, option in meta.items():
             if k == '_internal':
                 continue
 
             if option['on_epoch']:
                 fx = option['reduce_fx']
-                result[k] = fx(result[k])
+                if fx == torch.mean:
+                    reduced_val = weighted_mean(result[k], batch_sizes)
+                else:
+                    reduced_val = fx(result[k])
+
+                result[k] = reduced_val
 
         result['meta'] = meta
         return result

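The reduction above does two things: each output contributes its tracked batch sizes, which torch.stack(...).view(-1) flattens into a single 1-D tensor, and any key whose reduce_fx is torch.mean is re-reduced with those sizes as weights, while other reduce functions are applied unchanged. A standalone illustration with made-up numbers (plain torch, no Result objects):

    import torch

    # three step outputs, each having tracked one batch size (last batch is smaller)
    per_output_sizes = [torch.tensor([32]), torch.tensor([32]), torch.tensor([8])]
    batch_sizes = torch.stack(per_output_sizes).view(-1)      # tensor([32, 32, 8])

    # stacked per-step values of some metric reduced with torch.mean
    step_vals = torch.tensor([0.50, 0.50, 0.80])

    plain = torch.mean(step_vals)                              # 0.6000 (unweighted)
    weighted = torch.dot(step_vals, batch_sizes.float()) / batch_sizes.float().sum()
    # weighted -> 0.5333, the value the epoch-end reduction now reports
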
@@ -713,3 +731,10 @@ class EvalResult(Result):
         }
 
         return result
+
+
+def weighted_mean(result, weights):
+    weights = weights.to(result.device)
+    numerator = torch.dot(result.float(), weights.t().float())
+    result = numerator / weights.sum().float()
+    return result

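As a quick sanity check of weighted_mean (values are hypothetical; the function body is repeated so the snippet runs on its own): a metric of 0.2 over a batch of 32 samples and 0.4 over a batch of 8 samples should reduce to (0.2*32 + 0.4*8) / 40 = 0.24 rather than the unweighted 0.3.

    import torch

    def weighted_mean(result, weights):
        # same body as in the diff above
        weights = weights.to(result.device)
        numerator = torch.dot(result.float(), weights.t().float())
        result = numerator / weights.sum().float()
        return result

    print(weighted_mean(torch.tensor([0.2, 0.4]), torch.tensor([32, 8])))   # tensor(0.2400)
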
@@ -331,8 +331,14 @@ class TrainerEvaluationLoopMixin(ABC):
                 else:
                     output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
 
+                is_result_obj = isinstance(output, Result)
+
+                # track batch size for weighted average
+                if is_result_obj:
+                    output.track_batch_size(len(batch))
+
                 # allow only EvalResult when using structured results (from val_step)
-                if isinstance(output, Result) and not isinstance(output, EvalResult):
+                if is_result_obj and not isinstance(output, EvalResult):
                     m = 'only EvalResults or dicts are allowed from validation_step'
                     raise MisconfigurationException(m)

@@ -848,16 +848,13 @@ class TrainerTrainLoopMixin(ABC):
                 # add metrics to loggers
                 if using_results_obj:
                     metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
                 else:
                     metrics_to_log = opt_closure_result.training_step_output.log_metrics
-                batch_log_metrics.append(metrics_to_log)
 
                 # add metrics to progress bar
                 if using_results_obj:
                     step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
                 else:
-                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
                     step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
 
+                # track metrics
+                batch_log_metrics.append(metrics_to_log)
                 if len(step_pbar_metrics) > 0:
                     self.add_progress_bar_metrics(step_pbar_metrics)

@@ -1018,6 +1015,10 @@ class TrainerTrainLoopMixin(ABC):
         training_step_output_for_epoch_end = training_step_output
         is_result_obj = isinstance(training_step_output, Result)
 
+        # track batch size for weighted average
+        if is_result_obj:
+            training_step_output.track_batch_size(len(split_batch))
+
         # don't allow EvalResult in the training_step
         if isinstance(training_step_output, EvalResult):
             raise MisconfigurationException('training_step cannot return EvalResult, '