Add numpy support for the StreamingDataset 1/2 (#19050)
Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
This commit is contained in:
parent
1073276a58
commit
7eca9c1642
|
@ -14,6 +14,7 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
|
||||
|
@ -52,4 +53,7 @@ _TORCH_DTYPES_MAPPING = {
|
|||
19: torch.bool,
|
||||
}
|
||||
|
||||
_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values]
|
||||
_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}
|
||||
|
||||
_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
|
||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
import torch
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
|
||||
from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING
|
||||
from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
|
||||
|
||||
_PIL_AVAILABLE = RequirementCache("PIL")
|
||||
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
|
||||
|
@ -200,6 +200,61 @@ class NoHeaderTensorSerializer(Serializer):
|
|||
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1
|
||||
|
||||
|
||||
class NumpySerializer(Serializer):
|
||||
"""The NumpySerializer serialize and deserialize numpy to and from bytes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
|
||||
|
||||
def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
|
||||
dtype_indice = self._dtype_to_indice[item.dtype]
|
||||
data = [np.uint32(dtype_indice).tobytes()]
|
||||
data.append(np.uint32(len(item.shape)).tobytes())
|
||||
for dim in item.shape:
|
||||
data.append(np.uint32(dim).tobytes())
|
||||
data.append(item.tobytes(order="C"))
|
||||
return b"".join(data), None
|
||||
|
||||
def deserialize(self, data: bytes) -> np.ndarray:
|
||||
dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
|
||||
dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
|
||||
shape_size = np.frombuffer(data[4:8], np.uint32).item()
|
||||
shape = []
|
||||
for shape_idx in range(shape_size):
|
||||
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
|
||||
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
|
||||
if tensor.shape == shape:
|
||||
return tensor
|
||||
return np.reshape(tensor, shape)
|
||||
|
||||
def can_serialize(self, item: np.ndarray) -> bool:
|
||||
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) > 1
|
||||
|
||||
|
||||
class NoHeaderNumpySerializer(Serializer):
|
||||
"""The NoHeaderNumpySerializer serialize and deserialize numpy to and from bytes."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
|
||||
self._dtype: Optional[np.dtype] = None
|
||||
|
||||
def setup(self, data_format: str) -> None:
|
||||
self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]
|
||||
|
||||
def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
|
||||
dtype_indice: int = self._dtype_to_indice[item.dtype]
|
||||
return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"
|
||||
|
||||
def deserialize(self, data: bytes) -> np.ndarray:
|
||||
assert self._dtype
|
||||
return np.frombuffer(data, dtype=self._dtype)
|
||||
|
||||
def can_serialize(self, item: np.ndarray) -> bool:
|
||||
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) == 1
|
||||
|
||||
|
||||
class PickleSerializer(Serializer):
|
||||
"""The PickleSerializer serialize and deserialize python objects to and from bytes."""
|
||||
|
||||
|
@ -263,6 +318,8 @@ _SERIALIZERS = OrderedDict(
|
|||
"int": IntSerializer(),
|
||||
"jpeg": JPEGSerializer(),
|
||||
"bytes": BytesSerializer(),
|
||||
"no_header_numpy": NoHeaderNumpySerializer(),
|
||||
"numpy": NumpySerializer(),
|
||||
"no_header_tensor": NoHeaderTensorSerializer(),
|
||||
"tensor": TensorSerializer(),
|
||||
"pickle": PickleSerializer(),
|
||||
|
|
|
@ -21,11 +21,14 @@ import torch
|
|||
from lightning import seed_everything
|
||||
from lightning.data.streaming.serializers import (
|
||||
_AV_AVAILABLE,
|
||||
_NUMPY_DTYPES_MAPPING,
|
||||
_SERIALIZERS,
|
||||
_TORCH_DTYPES_MAPPING,
|
||||
_TORCH_VISION_AVAILABLE,
|
||||
IntSerializer,
|
||||
NoHeaderNumpySerializer,
|
||||
NoHeaderTensorSerializer,
|
||||
NumpySerializer,
|
||||
PickleSerializer,
|
||||
PILSerializer,
|
||||
TensorSerializer,
|
||||
|
@ -44,6 +47,8 @@ def test_serializers():
|
|||
"int",
|
||||
"jpeg",
|
||||
"bytes",
|
||||
"no_header_numpy",
|
||||
"numpy",
|
||||
"no_header_tensor",
|
||||
"tensor",
|
||||
"pickle",
|
||||
|
@ -124,6 +129,25 @@ def test_tensor_serializer():
|
|||
assert np.mean(ratio_bytes) > 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
|
||||
def test_numpy_serializer():
|
||||
seed_everything(42)
|
||||
|
||||
serializer_tensor = NumpySerializer()
|
||||
|
||||
shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
|
||||
for dtype in _NUMPY_DTYPES_MAPPING.values():
|
||||
# Those types aren't supported
|
||||
if dtype.name in ["object", "bytes", "str", "void"]:
|
||||
continue
|
||||
for shape in shapes:
|
||||
tensor = np.ones(shape, dtype=dtype)
|
||||
data, _ = serializer_tensor.serialize(tensor)
|
||||
deserialized_tensor = serializer_tensor.deserialize(data)
|
||||
assert deserialized_tensor.dtype == dtype
|
||||
np.testing.assert_equal(tensor, deserialized_tensor)
|
||||
|
||||
|
||||
def test_assert_bfloat16_tensor_serializer():
|
||||
serializer = TensorSerializer()
|
||||
tensor = torch.ones((10,), dtype=torch.bfloat16)
|
||||
|
@ -143,6 +167,19 @@ def test_assert_no_header_tensor_serializer():
|
|||
assert torch.equal(t, new_t)
|
||||
|
||||
|
||||
def test_assert_no_header_numpy_serializer():
|
||||
serializer = NoHeaderNumpySerializer()
|
||||
t = np.ones((10,))
|
||||
assert serializer.can_serialize(t)
|
||||
data, name = serializer.serialize(t)
|
||||
assert name == "no_header_numpy:10"
|
||||
assert serializer._dtype is None
|
||||
serializer.setup(name)
|
||||
assert serializer._dtype == np.dtype("float64")
|
||||
new_t = serializer.deserialize(data)
|
||||
np.testing.assert_equal(t, new_t)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']"
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue