# lightning/tests/tests_data/streaming/test_serializer.py
#
# 202 lines
# 6.4 KiB
# Python

# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from time import time
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.streaming.serializers import (
_AV_AVAILABLE,
_NUMPY_DTYPES_MAPPING,
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
_TORCH_VISION_AVAILABLE,
IntSerializer,
NoHeaderNumpySerializer,
NoHeaderTensorSerializer,
NumpySerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
VideoSerializer,
)
from lightning_utilities.core.imports import RequirementCache
# Availability flag for the "PIL" requirement; ``test_pil_serializer`` is skipped when it is falsy.
_PIL_AVAILABLE = RequirementCache("PIL")
def test_serializers():
    """Check that the serializer registry exposes exactly the expected keys, in order."""
    expected_names = [
        "video",
        "file",
        "pil",
        "int",
        "jpeg",
        "bytes",
        "no_header_numpy",
        "numpy",
        "no_header_tensor",
        "tensor",
        "pickle",
    ]
    registered_names = list(_SERIALIZERS.keys())
    assert registered_names == expected_names
def test_int_serializer():
    """Round-trip the integers 0..99 through ``IntSerializer``."""
    int_serializer = IntSerializer()

    for value in range(100):
        payload, _ = int_serializer.serialize(value)
        # The serialized form must be raw bytes and decode back to the same integer.
        assert isinstance(payload, bytes)
        assert value == int_serializer.deserialize(payload)
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.parametrize("mode", ["I", "L", "RGB"])
def test_pil_serializer(mode):
    """Round-trip a random 28x28 image through ``PILSerializer`` for the given PIL ``mode``."""
    # Local import: the test is skipped when PIL is not available.
    from PIL import Image

    pil_serializer = PILSerializer()

    pixels = np.random.randint(255, size=(28, 28), dtype=np.uint32)
    source_image = Image.fromarray(pixels).convert(mode)

    payload, _ = pil_serializer.serialize(source_image)
    assert isinstance(payload, bytes)

    # Convert back to mode "I" so the pixels can be compared with the uint32 source array.
    restored_image = pil_serializer.deserialize(payload).convert("I")
    restored_pixels = np.asarray(restored_image, dtype=np.uint32)
    assert isinstance(restored_image, Image.Image)

    # Validate data content
    assert np.array_equal(pixels, restored_pixels)
@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
    """Compare ``TensorSerializer`` with ``PickleSerializer`` on correctness, speed and size.

    For every dtype in ``_TORCH_DTYPES_MAPPING`` (except bfloat16) and a set of shapes, a tensor of
    ones is round-tripped through both serializers. Each round trip must preserve dtype and values.
    On average, the pickle round trip must take more than 1.6x the time and produce more than 2x
    the bytes of the tensor serializer.
    """
    # ``perf_counter`` is the monotonic, high-resolution clock intended for timing short durations;
    # the wall-clock ``time()`` can be adjusted by the system and may be too coarse for this.
    from time import perf_counter

    seed_everything(42)

    serializer_tensor = TensorSerializer()
    serializer_pickle = PickleSerializer()

    ratio_times = []
    ratio_bytes = []
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]

    for dtype in _TORCH_DTYPES_MAPPING.values():
        # Not serializable for some reasons. The check only depends on the dtype, so it is done
        # once per dtype instead of once per shape.
        if dtype in [torch.bfloat16]:
            continue

        for shape in shapes:
            tensor = torch.ones(shape, dtype=dtype)

            # Time a full serialize + deserialize round trip with the tensor serializer.
            t0 = perf_counter()
            data, _ = serializer_tensor.serialize(tensor)
            deserialized_tensor = serializer_tensor.deserialize(data)
            tensor_time = perf_counter() - t0
            tensor_bytes = len(data)
            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            # Same round trip with the pickle serializer as the baseline.
            t1 = perf_counter()
            data, _ = serializer_pickle.serialize(tensor)
            deserialized_tensor = serializer_pickle.deserialize(data)
            pickle_time = perf_counter() - t1
            pickle_bytes = len(data)
            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            ratio_times.append(pickle_time / tensor_time)
            ratio_bytes.append(pickle_bytes / tensor_bytes)

    assert np.mean(ratio_times) > 1.6
    assert np.mean(ratio_bytes) > 2
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_numpy_serializer():
    """Round-trip arrays of ones through ``NumpySerializer`` for each supported dtype and shape."""
    seed_everything(42)

    numpy_serializer = NumpySerializer()
    unsupported_names = ["object", "bytes", "str", "void"]
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]

    for dtype in _NUMPY_DTYPES_MAPPING.values():
        # Those types aren't supported
        if dtype.name in unsupported_names:
            continue

        for shape in shapes:
            array = np.ones(shape, dtype=dtype)
            payload, _ = numpy_serializer.serialize(array)
            restored = numpy_serializer.deserialize(payload)

            # Both the dtype and the contents must survive the round trip.
            assert restored.dtype == dtype
            np.testing.assert_equal(array, restored)
def test_assert_bfloat16_tensor_serializer():
    """Check that serializing a bfloat16 tensor with ``TensorSerializer`` raises ``TypeError``."""
    bf16_tensor = torch.ones((10,), dtype=torch.bfloat16)
    tensor_serializer = TensorSerializer()

    with pytest.raises(TypeError, match="Got unsupported ScalarType BFloat16"):
        tensor_serializer.serialize(bf16_tensor)
def test_assert_no_header_tensor_serializer():
    """Check the ``NoHeaderTensorSerializer`` flow: serialize, ``setup`` with the returned name, deserialize."""
    no_header_serializer = NoHeaderTensorSerializer()
    original = torch.ones((10,))

    payload, data_format = no_header_serializer.serialize(original)
    assert data_format == "no_header_tensor:1"

    # The dtype is unset until ``setup`` is called with the name returned by ``serialize``.
    assert no_header_serializer._dtype is None
    no_header_serializer.setup(data_format)
    assert no_header_serializer._dtype == torch.float32

    restored = no_header_serializer.deserialize(payload)
    assert torch.equal(original, restored)
def test_assert_no_header_numpy_serializer():
    """Check the ``NoHeaderNumpySerializer`` flow: serialize, ``setup`` with the returned name, deserialize."""
    no_header_serializer = NoHeaderNumpySerializer()
    original = np.ones((10,))
    assert no_header_serializer.can_serialize(original)

    payload, data_format = no_header_serializer.serialize(original)
    assert data_format == "no_header_numpy:10"

    # The dtype is unset until ``setup`` is called with the name returned by ``serialize``.
    assert no_header_serializer._dtype is None
    no_header_serializer.setup(data_format)
    assert no_header_serializer._dtype == np.dtype("float64")

    restored = no_header_serializer.deserialize(payload)
    np.testing.assert_equal(original, restored)
@pytest.mark.skipif(
    condition=not (_TORCH_VISION_AVAILABLE and _AV_AVAILABLE), reason="Requires: ['torchvision', 'av']"
)
def test_wav_deserialization(tmpdir):
    """Download a sample WAV file and check what ``VideoSerializer`` serializes and deserializes.

    Args:
        tmpdir: pytest temporary directory fixture; the downloaded file is stored inside it.
    """
    from torch.hub import download_url_to_file

    wav_path = os.path.join(tmpdir, "video.wav")
    key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"  # noqa E501
    url = f"https://download.pytorch.org/torchaudio/{key}"
    download_url_to_file(url, wav_path)

    video_serializer = VideoSerializer()
    assert video_serializer.can_serialize(wav_path)

    payload, extension = video_serializer.serialize(wav_path)
    # Serialized size expressed in MiB.
    assert len(payload) / 1024 / 1024 == 0.10380172729492188
    assert extension == "wav"

    video_frames, audio_frames, metadata = video_serializer.deserialize(payload)
    assert video_frames.shape == torch.Size([0, 1, 1, 3])
    assert audio_frames.shape == torch.Size([1, 54400])
    assert metadata == {"audio_fps": 16000}