Add numpy support for the StreamingDataset 1/2 (#19050)

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
thomas chaton 2023-11-22 18:00:15 +00:00 committed by GitHub
parent 1073276a58
commit 7eca9c1642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 1 deletions

View File

@ -14,6 +14,7 @@
import os
from pathlib import Path
import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
@ -52,4 +53,7 @@ _TORCH_DTYPES_MAPPING = {
19: torch.bool,
}
_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values]
_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}
_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

View File

@ -22,7 +22,7 @@ import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING
from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
@ -200,6 +200,61 @@ class NoHeaderTensorSerializer(Serializer):
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1
class NumpySerializer(Serializer):
"""The NumpySerializer serialize and deserialize numpy to and from bytes."""
def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
data = [np.uint32(dtype_indice).tobytes()]
data.append(np.uint32(len(item.shape)).tobytes())
for dim in item.shape:
data.append(np.uint32(dim).tobytes())
data.append(item.tobytes(order="C"))
return b"".join(data), None
def deserialize(self, data: bytes) -> np.ndarray:
dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
shape_size = np.frombuffer(data[4:8], np.uint32).item()
shape = []
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
if tensor.shape == shape:
return tensor
return np.reshape(tensor, shape)
def can_serialize(self, item: np.ndarray) -> bool:
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) > 1
class NoHeaderNumpySerializer(Serializer):
"""The NoHeaderNumpySerializer serialize and deserialize numpy to and from bytes."""
def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
self._dtype: Optional[np.dtype] = None
def setup(self, data_format: str) -> None:
self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]
def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice: int = self._dtype_to_indice[item.dtype]
return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"
def deserialize(self, data: bytes) -> np.ndarray:
assert self._dtype
return np.frombuffer(data, dtype=self._dtype)
def can_serialize(self, item: np.ndarray) -> bool:
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) == 1
class PickleSerializer(Serializer):
"""The PickleSerializer serialize and deserialize python objects to and from bytes."""
@ -263,6 +318,8 @@ _SERIALIZERS = OrderedDict(
"int": IntSerializer(),
"jpeg": JPEGSerializer(),
"bytes": BytesSerializer(),
"no_header_numpy": NoHeaderNumpySerializer(),
"numpy": NumpySerializer(),
"no_header_tensor": NoHeaderTensorSerializer(),
"tensor": TensorSerializer(),
"pickle": PickleSerializer(),

View File

@ -21,11 +21,14 @@ import torch
from lightning import seed_everything
from lightning.data.streaming.serializers import (
_AV_AVAILABLE,
_NUMPY_DTYPES_MAPPING,
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
_TORCH_VISION_AVAILABLE,
IntSerializer,
NoHeaderNumpySerializer,
NoHeaderTensorSerializer,
NumpySerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
@ -44,6 +47,8 @@ def test_serializers():
"int",
"jpeg",
"bytes",
"no_header_numpy",
"numpy",
"no_header_tensor",
"tensor",
"pickle",
@ -124,6 +129,25 @@ def test_tensor_serializer():
assert np.mean(ratio_bytes) > 2
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_numpy_serializer():
seed_everything(42)
serializer_tensor = NumpySerializer()
shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
for dtype in _NUMPY_DTYPES_MAPPING.values():
# Those types aren't supported
if dtype.name in ["object", "bytes", "str", "void"]:
continue
for shape in shapes:
tensor = np.ones(shape, dtype=dtype)
data, _ = serializer_tensor.serialize(tensor)
deserialized_tensor = serializer_tensor.deserialize(data)
assert deserialized_tensor.dtype == dtype
np.testing.assert_equal(tensor, deserialized_tensor)
def test_assert_bfloat16_tensor_serializer():
serializer = TensorSerializer()
tensor = torch.ones((10,), dtype=torch.bfloat16)
@ -143,6 +167,19 @@ def test_assert_no_header_tensor_serializer():
assert torch.equal(t, new_t)
def test_assert_no_header_numpy_serializer():
serializer = NoHeaderNumpySerializer()
t = np.ones((10,))
assert serializer.can_serialize(t)
data, name = serializer.serialize(t)
assert name == "no_header_numpy:10"
assert serializer._dtype is None
serializer.setup(name)
assert serializer._dtype == np.dtype("float64")
new_t = serializer.deserialize(data)
np.testing.assert_equal(t, new_t)
@pytest.mark.skipif(
condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']"
)