lightning/tests/metrics/functional/test_image_gradients.py

110 lines
2.9 KiB
Python

import pytest
import torch
from pytorch_lightning.metrics.functional.image_gradients import image_gradients
def test_invalid_input_img_type():
    """Check that a nested Python list (not a tensor) is rejected with ``TypeError``."""
    # 3x3 nested list holding the values 1..9, row by row.
    list_image = [list(range(start, start + 3)) for start in (1, 4, 7)]

    with pytest.raises(TypeError):
        image_gradients(list_image)
def test_invalid_input_ndims():
    """Check that an input tensor with an invalid number of dimensions is rejected.

    A 2-D ``(HEIGHT, WIDTH)`` tensor is passed in and the call is expected
    to raise ``RuntimeError``.
    """
    BATCH_SIZE, CHANNELS = 1, 1
    HEIGHT, WIDTH = 5, 5

    num_values = BATCH_SIZE * HEIGHT * WIDTH * CHANNELS
    flat_values = torch.arange(0, num_values, dtype=torch.float32)
    # Deliberately drop the batch and channel axes.
    image_2d = flat_values.reshape(HEIGHT, WIDTH)

    with pytest.raises(RuntimeError):
        image_gradients(image_2d)
def test_multi_batch_image_gradients():
    """Test that gradients are calculated correctly with a non-unity batch size.

    Every batch element is the same 5x5 single-channel image, so each one is
    expected to yield the same ``dy`` and ``dx``. The example input-output
    pair is taken from TF's implementation of image-gradients.
    """
    BATCH_SIZE = 5
    HEIGHT = 5
    WIDTH = 5
    CHANNELS = 1

    single_channel_img = torch.arange(0, 1 * HEIGHT * WIDTH * CHANNELS, dtype=torch.float32)
    single_channel_img = torch.reshape(single_channel_img, (CHANNELS, HEIGHT, WIDTH))
    # Replicate the same image along a new leading batch axis.
    image = torch.stack([single_channel_img for _ in range(BATCH_SIZE)], dim=0)

    true_dy = torch.Tensor([
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [0., 0., 0., 0., 0.],
    ])
    true_dx = torch.Tensor([
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 0.],
    ])

    dy, dx = image_gradients(image)

    expected_shape = (BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)
    assert dy.shape == expected_shape
    assert dx.shape == expected_shape

    for batch_id in range(BATCH_SIZE):
        assert torch.allclose(dy[batch_id, 0, :, :], true_dy), f"dy fails test for batch {batch_id}"
        # dx must be verified as well; otherwise an error in the horizontal
        # gradient of a batched input would go unnoticed.
        assert torch.allclose(dx[batch_id, 0, :, :], true_dx), f"dx fails test for batch {batch_id}"
def test_image_gradients():
    """Verify ``dy`` and ``dx`` against known values for a single 5x5 image.

    The example input-output pair is taken from TF's implementation of
    image-gradients.
    """
    BATCH_SIZE, CHANNELS = 1, 1
    HEIGHT, WIDTH = 5, 5

    num_values = BATCH_SIZE * HEIGHT * WIDTH * CHANNELS
    image = torch.arange(0, num_values, dtype=torch.float32).reshape(
        BATCH_SIZE, CHANNELS, HEIGHT, WIDTH
    )

    # Expected dy: 5 everywhere except an all-zero last row.
    expected_dy = torch.Tensor([[5.0] * WIDTH] * (HEIGHT - 1) + [[0.0] * WIDTH])
    # Expected dx: 1 everywhere except an all-zero last column.
    expected_dx = torch.Tensor([[1.0] * (WIDTH - 1) + [0.0]] * HEIGHT)

    dy, dx = image_gradients(image)

    assert torch.allclose(dy, expected_dy), "dy fails test"
    assert torch.allclose(dx, expected_dx), "dx fails tests"