from typing import Optional

import torch


def reduce(to_reduce: torch.Tensor, reduction: Optional[str]) -> torch.Tensor:
    """
    Reduces a given tensor by a given reduction method.

    Args:
        to_reduce: the tensor which shall be reduced
        reduction: a string specifying the reduction method
            ('elementwise_mean', 'none', 'sum'). ``None`` is accepted as an
            alias for ``'none'`` (no reduction).

    Returns:
        The reduced tensor: a 0-dim tensor for 'elementwise_mean' / 'sum',
        or ``to_reduce`` itself (same object, not a copy) for 'none'.

    Raises:
        ValueError: if an invalid reduction parameter was given.
    """
    if reduction == 'elementwise_mean':
        # Mean over all elements (no dim argument), yielding a scalar tensor.
        return torch.mean(to_reduce)
    if reduction == 'none' or reduction is None:
        # Pass-through: the input object is returned unchanged.
        return to_reduce
    if reduction == 'sum':
        # Sum over all elements, yielding a scalar tensor.
        return torch.sum(to_reduce)
    # Name the offending value and the accepted options so misconfiguration
    # is easy to diagnose from the traceback alone.
    raise ValueError(
        f"Reduction parameter unknown: {reduction!r}. "
        "Expected one of 'elementwise_mean', 'none', 'sum'."
    )