25 lines
675 B
Python
25 lines
675 B
Python
|
import torch
|
||
|
|
||
|
|
||
|
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """Reduce a tensor using the named reduction method.

    Args:
        to_reduce: the tensor to reduce.
        reduction: the reduction method, one of ``'elementwise_mean'``
            (mean over all elements), ``'none'`` (no reduction) or
            ``'sum'`` (sum over all elements).

    Returns:
        The reduced tensor. For ``'elementwise_mean'`` and ``'sum'`` this is a
        0-dim tensor; for ``'none'`` it is ``to_reduce`` itself (not a copy).

    Raises:
        ValueError: if ``reduction`` is not one of the supported methods.
    """
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    # Keep the original message prefix so existing ``match=`` checks keep
    # working, but name the offending value and the valid options to make
    # misconfiguration easy to diagnose.
    raise ValueError(
        f"Reduction parameter unknown. Got {reduction!r}; "
        "expected one of 'elementwise_mean', 'none', 'sum'."
    )
|