16 lines
487 B
Python
16 lines
487 B
Python
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from pytorch_lightning.metrics.functional.reduction import reduce
|
||
|
|
||
|
|
||
|
def test_reduce():
|
||
|
start_tensor = torch.rand(50, 40, 30)
|
||
|
|
||
|
assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
|
||
|
assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
|
||
|
assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)
|
||
|
|
||
|
with pytest.raises(ValueError):
|
||
|
reduce(start_tensor, 'error_reduction')
|