202 lines
6.4 KiB
Python
202 lines
6.4 KiB
Python
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import os
|
|
import sys
|
|
from time import time
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from lightning import seed_everything
|
|
from lightning.data.streaming.serializers import (
|
|
_AV_AVAILABLE,
|
|
_NUMPY_DTYPES_MAPPING,
|
|
_SERIALIZERS,
|
|
_TORCH_DTYPES_MAPPING,
|
|
_TORCH_VISION_AVAILABLE,
|
|
IntSerializer,
|
|
NoHeaderNumpySerializer,
|
|
NoHeaderTensorSerializer,
|
|
NumpySerializer,
|
|
PickleSerializer,
|
|
PILSerializer,
|
|
TensorSerializer,
|
|
VideoSerializer,
|
|
)
|
|
from lightning_utilities.core.imports import RequirementCache
|
|
|
|
# Availability flag for the "PIL" requirement; the PIL round-trip test below is
# skipped when it evaluates falsy.
_PIL_AVAILABLE = RequirementCache(requirement="PIL")
|
|
|
|
|
|
def test_serializers():
    """The serializer registry exposes exactly these keys, in this order."""
    expected_keys = [
        "video",
        "file",
        "pil",
        "int",
        "jpeg",
        "bytes",
        "no_header_numpy",
        "numpy",
        "no_header_tensor",
        "tensor",
        "pickle",
    ]
    registered_keys = list(_SERIALIZERS.keys())
    assert registered_keys == expected_keys
|
|
|
|
|
|
def test_int_serializer():
    """Every integer in ``range(100)`` round-trips through IntSerializer as bytes."""
    int_serializer = IntSerializer()

    for value in range(100):
        payload, _name = int_serializer.serialize(value)
        assert isinstance(payload, bytes)
        restored = int_serializer.deserialize(payload)
        assert value == restored
|
|
|
|
|
|
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.parametrize("mode", ["I", "L", "RGB"])
def test_pil_serializer(mode):
    """Round-trip a 28x28 PIL image in ``mode`` through PILSerializer and compare pixels.

    Source pixels are drawn from ``[0, 255)``, so they fit in 8 bits. The decoded
    image is converted back to mode "I" so every parametrized mode is compared
    against the same uint32 source array.
    """
    from PIL import Image

    # Seed like the other tests in this module so a failing pixel pattern is reproducible.
    seed_everything(42)

    serializer = PILSerializer()

    np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
    img = Image.fromarray(np_data).convert(mode)

    data, _ = serializer.serialize(img)
    assert isinstance(data, bytes)

    deserialized_img = serializer.deserialize(data)
    # Check the type before calling Image methods on it, so a wrong return type
    # fails with a clear assertion rather than an AttributeError on ``convert``.
    assert isinstance(deserialized_img, Image.Image)

    np_dec_data = np.asarray(deserialized_img.convert("I"), dtype=np.uint32)

    # Validate data content
    assert np.array_equal(np_data, np_dec_data)
|
|
|
|
|
|
@pytest.mark.flaky(reruns=3)
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
    """TensorSerializer round-trips every supported dtype/shape and outperforms pickle.

    For each dtype in ``_TORCH_DTYPES_MAPPING`` except bfloat16, and each shape in
    ``shapes``, a tensor of ones is serialized and deserialized with both
    ``TensorSerializer`` and ``PickleSerializer``. Both must reproduce the tensor
    exactly (values and dtype). Averaged over all cases, ``TensorSerializer`` must be
    more than 1.6x faster and its output less than half the size of pickle's.
    """
    # perf_counter is monotonic and uses the highest-resolution clock available.
    # The wall-clock ``time()`` can be adjusted mid-run and, on coarse clocks,
    # return equal readings for this sub-millisecond work (tensor_time == 0).
    from time import perf_counter

    seed_everything(42)

    serializer_tensor = TensorSerializer()
    serializer_pickle = PickleSerializer()

    ratio_times = []
    ratio_bytes = []
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
    for dtype in _TORCH_DTYPES_MAPPING.values():
        # bfloat16 is rejected by TensorSerializer (see
        # test_assert_bfloat16_tensor_serializer). The check depends only on the
        # dtype, so it is made once per dtype rather than once per shape.
        if dtype in [torch.bfloat16]:
            continue
        for shape in shapes:
            tensor = torch.ones(shape, dtype=dtype)

            t0 = perf_counter()
            data, _ = serializer_tensor.serialize(tensor)
            deserialized_tensor = serializer_tensor.deserialize(data)
            tensor_time = perf_counter() - t0
            tensor_bytes = len(data)

            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            t1 = perf_counter()
            data, _ = serializer_pickle.serialize(tensor)
            deserialized_tensor = serializer_pickle.deserialize(data)
            pickle_time = perf_counter() - t1
            pickle_bytes = len(data)

            assert deserialized_tensor.dtype == dtype
            assert torch.equal(tensor, deserialized_tensor)

            ratio_times.append(pickle_time / tensor_time)
            ratio_bytes.append(pickle_bytes / tensor_bytes)

    # Timing ratios are noisy; presumably why the test carries the flaky marker.
    assert np.mean(ratio_times) > 1.6
    assert np.mean(ratio_bytes) > 2
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_numpy_serializer():
    """NumpySerializer round-trips arrays of ones for each supported dtype and shape.

    dtypes named "object", "bytes", "str" or "void" are excluded as unsupported.
    """
    seed_everything(42)

    serializer = NumpySerializer()

    unsupported_names = ("object", "bytes", "str", "void")
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
    supported_dtypes = [dt for dt in _NUMPY_DTYPES_MAPPING.values() if dt.name not in unsupported_names]

    for dtype in supported_dtypes:
        for shape in shapes:
            arr = np.ones(shape, dtype=dtype)
            payload, _ = serializer.serialize(arr)
            restored = serializer.deserialize(payload)
            assert restored.dtype == dtype
            np.testing.assert_equal(arr, restored)
|
|
|
|
|
|
def test_assert_bfloat16_tensor_serializer():
    """Serializing a bfloat16 tensor raises a TypeError naming the BFloat16 scalar type."""
    serializer = TensorSerializer()
    bf16_tensor = torch.ones((10,), dtype=torch.bfloat16)

    expected_message = "Got unsupported ScalarType BFloat16"
    with pytest.raises(TypeError, match=expected_message):
        serializer.serialize(bf16_tensor)
|
|
|
|
|
|
def test_assert_no_header_tensor_serializer():
    """The name returned by serialize() lets setup() recover the dtype for deserialize()."""
    serializer = NoHeaderTensorSerializer()
    original = torch.ones((10,))

    payload, name = serializer.serialize(original)
    assert name == "no_header_tensor:1"

    # No dtype is known until setup() is called with the serializer name.
    assert serializer._dtype is None
    serializer.setup(name)
    assert serializer._dtype == torch.float32

    restored = serializer.deserialize(payload)
    assert torch.equal(original, restored)
|
|
|
|
|
|
def test_assert_no_header_numpy_serializer():
    """The name returned by serialize() lets setup() recover the dtype for deserialize()."""
    serializer = NoHeaderNumpySerializer()
    original = np.ones((10,))
    assert serializer.can_serialize(original)

    payload, name = serializer.serialize(original)
    assert name == "no_header_numpy:10"

    # No dtype is known until setup() is called with the serializer name.
    assert serializer._dtype is None
    serializer.setup(name)
    assert serializer._dtype == np.dtype("float64")

    restored = serializer.deserialize(payload)
    np.testing.assert_equal(original, restored)
|
|
|
|
|
|
@pytest.mark.skipif(
    condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']"
)
def test_wav_deserialization(tmpdir):
    """Download a short WAV clip and round-trip it through VideoSerializer.

    Checks the serialized payload size, the reported serializer name, and the
    shapes/metadata returned by ``deserialize``. Needs network access to
    download.pytorch.org.
    """
    from torch.hub import download_url_to_file

    audio_file = os.path.join(tmpdir, "video.wav")
    key = "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
    download_url_to_file(f"https://download.pytorch.org/torchaudio/{key}", audio_file)

    serializer = VideoSerializer()
    assert serializer.can_serialize(audio_file)
    data, name = serializer.serialize(audio_file)
    # Compare the exact byte count rather than a float-derived MiB value
    # (108844 bytes == 0.10380172729492188 MiB).
    assert len(data) == 108844
    assert name == "wav"

    vframes, aframes, info = serializer.deserialize(data)
    # No video frames; audio comes back as a (1, 54400) tensor.
    assert vframes.shape == torch.Size([0, 1, 1, 3])
    assert aframes.shape == torch.Size([1, 54400])
    assert info == {"audio_fps": 16000}
|