lightning/pytorch_lightning/metrics/functional/reduction.py

25 lines
675 B
Python

import torch
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduces a given tensor by a given reduction method

    Args:
        to_reduce : the tensor, which shall be reduced
        reduction : a string specifying the reduction method
            ('elementwise_mean', 'none', 'sum')

    Return:
        reduced Tensor: the mean over all elements for 'elementwise_mean',
        the sum over all elements for 'sum', or ``to_reduce`` itself
        (the same object, not a copy) for 'none'

    Raise:
        ValueError if an invalid reduction parameter was given
    """
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        # Pass-through: the caller gets back the very tensor it handed in.
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    # Keep the original message as a prefix so existing callers that match on
    # it still work, but also report what was received and what is accepted.
    raise ValueError(
        "Reduction parameter unknown. "
        "Expected one of 'elementwise_mean', 'none', 'sum' "
        f"but got {reduction!r}."
    )