# lightning/tests/tests_data/cache/test_serializer.py
#
# 116 lines
# 3.7 KiB
# Python

# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from time import time
import numpy as np
import pytest
import torch
from lightning import seed_everything
from lightning.data.cache.serializers import (
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
IntSerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
)
from lightning_utilities.core.imports import RequirementCache
# Boolean-like availability check for the "PIL" requirement; used below to skip the PIL test when it is missing.
_PIL_AVAILABLE = RequirementCache("PIL")
def test_serializers():
    """The serializer registry exposes exactly these names, in this order."""
    expected_names = ["file", "pil", "int", "jpeg", "bytes", "tensor", "pickle"]
    registered_names = list(_SERIALIZERS.keys())
    assert registered_names == expected_names
def test_int_serializer():
    """Round-trip every integer in ``range(100)`` through ``IntSerializer``."""
    int_serializer = IntSerializer()
    for value in range(100):
        payload, _ = int_serializer.serialize(value)
        # The serialized form must be raw bytes.
        assert isinstance(payload, bytes)
        restored = int_serializer.deserialize(payload)
        assert value == restored
@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']")
@pytest.mark.parametrize("mode", ["I", "L", "RGB"])
def test_pil_serializer(mode):
    """Round-trip a 28x28 PIL image in the given ``mode`` through ``PILSerializer``.

    The deserializer must return a ``PIL.Image.Image``. After normalizing it to mode ``"I"``,
    its pixels must equal the source array.
    """
    # Seed so that a failing pixel pattern is reproducible (matches test_tensor_serializer).
    seed_everything(42)
    serializer = PILSerializer()
    from PIL import Image

    # randint's upper bound is exclusive: values lie in [0, 254], so they survive conversion
    # to the 8-bit modes ("L", "RGB") without clipping.
    np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32)
    img = Image.fromarray(np_data).convert(mode)

    data, _ = serializer.serialize(img)
    assert isinstance(data, bytes)

    deserialized_img = serializer.deserialize(data)
    # Check the deserializer's own return value, before any mode conversion rebinds it.
    assert isinstance(deserialized_img, Image.Image)

    # Normalize every parametrized mode to 32-bit integer pixels so they compare against np_data.
    np_dec_data = np.asarray(deserialized_img.convert("I"), dtype=np.uint32)
    # Validate data content
    assert np.array_equal(np_data, np_dec_data)
def _timed_roundtrip(serializer, tensor):
    """Serialize then deserialize ``tensor`` with ``serializer`` and time the round-trip.

    Returns:
        A ``(deserialized, elapsed_seconds, num_bytes)`` tuple. ``elapsed_seconds`` covers both
        the ``serialize`` and ``deserialize`` calls; ``num_bytes`` is the serialized payload length.
    """
    # perf_counter is monotonic and high-resolution; time.time() is wall-clock and can be too
    # coarse for these tiny payloads (a zero duration would make the ratio divide by zero).
    from time import perf_counter

    start = perf_counter()
    data, _ = serializer.serialize(tensor)
    deserialized = serializer.deserialize(data)
    elapsed = perf_counter() - start
    return deserialized, elapsed, len(data)


@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_tensor_serializer():
    """Compare ``TensorSerializer`` with ``PickleSerializer`` across dtypes and shapes.

    Both serializers must reproduce every tensor exactly (values and dtype). Averaged over all
    cases, pickle must be more than 3.5x slower and its payloads more than 2x larger than the
    tensor serializer's.
    """
    seed_everything(42)
    serializer_tensor = TensorSerializer()
    serializer_pickle = PickleSerializer()
    ratio_times = []
    ratio_bytes = []
    shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]

    for dtype in _TORCH_DTYPES_MAPPING.values():
        # TensorSerializer rejects bfloat16 with a TypeError; that case is covered by
        # test_assert_bfloat16_tensor_serializer.
        if dtype == torch.bfloat16:
            continue
        for shape in shapes:
            tensor = torch.ones(shape, dtype=dtype)

            tensor_out, tensor_time, tensor_bytes = _timed_roundtrip(serializer_tensor, tensor)
            assert tensor_out.dtype == dtype
            assert torch.equal(tensor, tensor_out)

            pickle_out, pickle_time, pickle_bytes = _timed_roundtrip(serializer_pickle, tensor)
            assert pickle_out.dtype == dtype
            assert torch.equal(tensor, pickle_out)

            ratio_times.append(pickle_time / tensor_time)
            ratio_bytes.append(pickle_bytes / tensor_bytes)

    mean_ratio_time = np.mean(ratio_times)
    mean_ratio_bytes = np.mean(ratio_bytes)
    assert mean_ratio_time > 3.5, f"pickle/tensor time ratio too low: {mean_ratio_time:.2f}"
    assert mean_ratio_bytes > 2, f"pickle/tensor size ratio too low: {mean_ratio_bytes:.2f}"
def test_assert_bfloat16_tensor_serializer():
    """``TensorSerializer.serialize`` must raise ``TypeError`` for bfloat16 tensors."""
    bf16_tensor = torch.ones((10,), dtype=torch.bfloat16)
    tensor_serializer = TensorSerializer()
    expected_message = "Got unsupported ScalarType BFloat16"
    with pytest.raises(TypeError, match=expected_message):
        tensor_serializer.serialize(bf16_tensor)