lightning/tests/tests_data/datasets/test_iterable.py

610 lines
18 KiB
Python

import math
import sys
from collections import Counter
from functools import partial
from typing import Any, Dict
import lightning
import pytest
import torch
from lightning.data.datasets.iterable import (
DataLoader,
LightningIterableDataset,
_Chunk,
_Stateful,
_StatefulIterableDataset,
)
class Foo1:
def state_dict(self, returned_samples: int) -> Dict[str, Any]:
pass
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass
class Foo2:
def state_dict(self) -> Dict[str, Any]:
pass
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
pass
class Bar1:
pass
class Bar2:
def state_dict(self) -> Dict[str, Any]:
pass
@pytest.mark.parametrize(
("klass", "fullfilled"),
[
pytest.param(Foo1, True),
pytest.param(Foo2, True),
pytest.param(Bar1, False),
pytest.param(Bar2, False),
],
)
def test_serializable(klass, fullfilled):
assert isinstance(klass(), _Stateful) == fullfilled
class DummyIterableDataset(_StatefulIterableDataset):
def __init__(self, length: int):
super().__init__()
self.length = length
self.index = 0
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return 0
class WrongDummySerializableIterableDataset1(DummyIterableDataset):
def state_dict(self):
return {"length": self.length, "index": self.index}
class WrongDummySerializableIterableDataset2(DummyIterableDataset):
def load_state_dict(self, state_dict):
self.length = state_dict.pop("length")
self.index = state_dict.pop("index")
class WorkingDummySerializableIterableDataset(
WrongDummySerializableIterableDataset1, WrongDummySerializableIterableDataset2
):
pass
@pytest.mark.parametrize(
("klass", "missing_method"),
[
pytest.param(WrongDummySerializableIterableDataset1, "load_state_dict"),
pytest.param(WrongDummySerializableIterableDataset2, "state_dict"),
],
)
def test_required_abstract_methods_serializable_dataset(klass, missing_method):
with pytest.raises(
TypeError,
match=f"Can't instantiate abstract class {klass.__name__} with abstract method.* {missing_method}",
):
klass(10)
def test_serialization_iterable_dataset():
dset = WorkingDummySerializableIterableDataset(10)
dset_iter = iter(dset)
assert dset_iter is dset
for i in range(10):
assert dset.state_dict() == {"length": 10, "index": i}
next(dset_iter)
assert dset.state_dict() == {"length": 10, "index": i + 1}
def test_iteration_serializable_iterable_dataset():
dset = WorkingDummySerializableIterableDataset(10)
i = 0
for _ in dset:
i = i + 1
assert i == 10
def test_resume_iterable_dataset():
dset1 = WorkingDummySerializableIterableDataset(10)
dset1_iter = iter(dset1)
for _ in range(5):
next(dset1_iter)
assert dset1.state_dict() == {"length": 10, "index": 5}
dset2 = WorkingDummySerializableIterableDataset(12)
dset2.load_state_dict(dset1.state_dict())
assert dset2.length == 10
assert dset2.index == 5
i = 0
for _ in dset2:
i = i + 1
assert i == 5
dset2.length = 12
for _ in dset2:
i = i + 1
assert i == 7
assert dset2.state_dict() == {"length": 12, "index": 12}
class WrongChunkedDataset1(LightningIterableDataset):
def load_chunk(self, curr_chunk: int):
return [(curr_chunk, i) for i in range(self._chunk_size)]
class WrongChunkedDataset2(LightningIterableDataset):
def load_sample_from_chunk(self, curr_chunk, curr_index):
return curr_chunk[curr_index]
class WorkingChunkedDataset(WrongChunkedDataset1, WrongChunkedDataset2):
pass
@pytest.mark.parametrize(
("klass", "missing_method"),
[
pytest.param(WrongChunkedDataset1, "load_sample_from_chunk"),
pytest.param(WrongChunkedDataset2, "load_chunk"),
],
)
def test_required_abstract_methods_chunked_dataset(klass, missing_method):
with pytest.raises(
TypeError,
match=f"Can't instantiate abstract class {klass.__name__} with abstract method.* {missing_method}",
):
klass([10], 10)
def test_chunked_dataset_iteration():
dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False)
curr_item = 0
for i, item in enumerate(dset):
assert item[0] == curr_item
assert item[1] == i % 2
curr_item += item[1]
# goes to 4 but increases again in last item
assert curr_item == 5
assert i == 9
@pytest.mark.parametrize("lazy_shuffle", [False, True])
def test_chunk_dataset_iteration_shuffle(lazy_shuffle):
dset = WorkingChunkedDataset(
list(range(5)),
chunk_size=2,
shuffle=True,
seed=12345,
wrap=False,
lazy_shuffle=lazy_shuffle,
)
counter = Counter()
series = []
unexpected_series = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
series_keys = []
unexpected_series_keys = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
for item, key in dset:
counter.update({item: 1})
series.append(item)
series_keys.append(key)
for val in counter.values():
assert val == 2
# with shuffling it can't be equal to ordered!
assert series != unexpected_series
assert series_keys != unexpected_series_keys
def test_chunked_dataset_wrap():
dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=True, seed=12345, wrap=True)
dset_iter = iter(dset)
# dataset has length 10, so this wraps 2 times
for i in range(21):
_ = next(dset_iter)
def test_chunked_dataset_resume_and_reset():
dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False)
for i, item in enumerate(dset):
assert item[0] == 0
assert item[1] == i
if i == 1:
break
# Every iterator starts from scratch
for i, item in enumerate(dset):
assert item[0] == 0
assert item[1] == i
if i == 1:
break
# this would be set when we load from state dict
dset._start_index_sample = 1
for i, item in enumerate(dset):
assert item[0] == i
assert item[1] == (i + 1) % 2
if i == 1:
break
dset._start_index_chunk == 1
for i, item in enumerate(dset):
assert item[0] == 1
assert item[1] == (i + 1) % 2
if i == 1:
break
@pytest.mark.parametrize("shuffle", [False, True])
def test_chunked_dataset_serialization(shuffle):
dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=shuffle, wrap=False)
assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}
dset_iter = iter(dset)
assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}
dset.load_state_dict(dset.state_dict(0, 0))
assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}
dset_iter = iter(dset)
# throw away first few batches to alter internal state
for i in range(3):
next(dset_iter)
curr_state = dset.state_dict(3, 0)
original = [next(dset_iter) for _ in range(5)]
dset.load_state_dict(curr_state)
dset_iter = iter(dset)
after_loading = [next(dset_iter) for _ in range(5)]
# this isn't because we always skip to beginning of next chunk when loading and not already at beginning of chunk
assert original != after_loading
assert original[1:] == after_loading[:-1]
# this actually puts us already on beginning of a chunk, but we'll forward to beginning of next chunk,
# otherwise we'd two times resume from same checkpoint and assert equal behavior
dset.load_state_dict(curr_state)
_ = [next(dset_iter) for _ in range(2)]
new_curr_state = dset.state_dict(6, 0)
new_original = [next(dset_iter) for _ in range(3)]
dset.load_state_dict(new_curr_state)
new_after_loading = [next(dset_iter) for _ in range(3)]
# this is equal since we exactly stopped at beginning of new chunk
assert new_original == new_after_loading
class ChunkedTestDatasetDistributed(WorkingChunkedDataset):
def _apply_sharding(self):
super()._apply_sharding()
assert len(self._local_chunks) == self.expected_num_chunks
for i in range(1, len(self._local_chunks)):
assert self._local_chunks[i]._chunk_data - self._local_chunks[i - 1]._chunk_data == self.expected_step_width
def sharding_test(fabric: lightning.Fabric, num_workers):
dset = ChunkedTestDatasetDistributed(list(range(50)), 2, shuffle=False, wrap=False)
num_shards = max(1, num_workers) * fabric.world_size
# num_workers = 0 still has a single worker (the main process)
expected_num_chunks = 50 // num_shards
dset.expected_num_chunks = expected_num_chunks
dset.expected_step_width = fabric.world_size * max(num_workers, 1)
num_samples_per_rank = max(num_workers, 1) * 2 * expected_num_chunks
loader = torch.utils.data.DataLoader(dset, num_workers=num_workers)
for i, _ in enumerate(loader):
fabric.barrier()
assert i == num_samples_per_rank - 1
@pytest.mark.parametrize(
("num_workers", "world_size"),
[
pytest.param(0, 1),
pytest.param(
0,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
1,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
2,
1,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
2,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
],
)
def test_sharding(num_workers, world_size):
fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn")
fabric.launch(partial(sharding_test, num_workers=num_workers))
def sharding_resume_test(fabric: lightning.Fabric, num_workers):
chunk_size = 2
dset = WorkingChunkedDataset(list(range(100)), chunk_size, shuffle=False, wrap=False)
loader = torch.utils.data.DataLoader(dset, num_workers=num_workers, shuffle=False)
num_shards = max(1, num_workers) * fabric.world_size
for i in [23, 37, 10, 20]:
curr_index = math.ceil(i / num_shards / chunk_size) * num_shards * chunk_size
next_chunk = math.ceil(curr_index / chunk_size)
curr_state = dset.state_dict(i, num_workers=num_workers)
assert curr_state == {"current_chunk": next_chunk, "current_sample_in_chunk": 0}
dset.load_state_dict(curr_state)
loader = torch.utils.data.DataLoader(dset, num_workers=num_workers, shuffle=False)
# calculate starting chunks
# next_chunk + fabric.global_rank * max(1,num_workers) determines the base offset for each rank
# i % chunk_size makes sure that workers are alternating
# e.g. w0 returns first element of first chunk then w1 returns first element of second chunk then w0 returns
# second element of first chunk etc.
# i // num_shards * num_shards progresses to next chunks
curr_worker_chunk = {
i: next_chunk
+ fabric.global_rank * max(1, num_workers)
+ i % max(1, num_workers)
+ i // (chunk_size * num_shards)
for i in range(max(1, num_workers))
}
curr_worker_chunk_elem = {i: 0 for i in range(max(1, num_workers))}
for i, batch in enumerate(loader):
curr_worker = i % max(1, num_workers)
assert batch[0] == curr_worker_chunk[curr_worker]
assert batch[1] == curr_worker_chunk_elem[curr_worker]
curr_worker_chunk_elem[curr_worker] += 1
if curr_worker_chunk_elem[curr_worker] == chunk_size:
curr_worker_chunk[curr_worker] += num_shards
curr_worker_chunk_elem[curr_worker] = 0
fabric.barrier()
@pytest.mark.parametrize(
("num_workers", "world_size"),
[
pytest.param(0, 1),
pytest.param(
0,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
1,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
1,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
2,
1,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
2,
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
],
)
def test_chunked_dataset_sharded_state_dict_resume(num_workers, world_size):
fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn")
fabric.launch(partial(sharding_resume_test, num_workers=num_workers))
@pytest.mark.parametrize("chunk_size", [20, 30, 40])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("shuffle_seed", [None, 123])
@pytest.mark.parametrize("delayed_start", [False, True])
def test_chunk(chunk_size, shuffle, shuffle_seed, delayed_start):
data = list(range(chunk_size))
delayed_start_index = int(delayed_start) * (chunk_size - 10)
chunk = _Chunk(data, chunk_size=chunk_size, start_index=delayed_start_index)
linear_permutation = tuple(range(chunk_size))
assert chunk.index_permutations == linear_permutation
for i, index in enumerate(chunk):
assert index == delayed_start_index + i
assert chunk.chunk_size == chunk_size
if shuffle:
generator = torch.Generator().manual_seed(shuffle_seed) if shuffle_seed else None
chunk = chunk.shuffle(generator=generator)
old_permutation = chunk.index_permutations
assert old_permutation != linear_permutation
new_perm = []
for i, index in enumerate(chunk):
new_perm.append(index)
assert tuple(new_perm) == tuple([old_permutation[k] for k in range(delayed_start_index, chunk_size)])
assert len(new_perm) == chunk_size - delayed_start_index
if shuffle_seed:
chunk = chunk.shuffle(generator=generator.manual_seed(shuffle_seed))
assert chunk.index_permutations == old_permutation
assert chunk.chunk_size == chunk_size
class MyDataset(_StatefulIterableDataset):
def __init__(self, length):
self.length = length
self.samples = list(range(length))
self.curr_iter = 0
def __iter__(self):
for sample in self.samples[self.curr_iter :]:
yield sample
self.curr_iter += 1
def state_dict(self, returned_samples, num_workers):
return {"curr_iter": returned_samples, "num_workers": num_workers}
def load_state_dict(self, state_dict):
self.curr_iter = state_dict.pop("curr_iter")
@pytest.mark.parametrize("batch_size", [1, 2, 3])
@pytest.mark.parametrize(
"num_workers",
[
pytest.param(0),
pytest.param(
1,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
pytest.param(
2,
marks=pytest.mark.skipif(
sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
),
),
],
)
@pytest.mark.parametrize("prefetch_factor", [1, 2, 3])
@pytest.mark.parametrize("length", [100, 101])
@pytest.mark.parametrize("num_batches", [1, 2, 7])
def test_resumable_loader(batch_size, num_workers, prefetch_factor, length, num_batches):
dset = MyDataset(length)
loader = DataLoader(
dset,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor if num_workers > 0 else None,
)
loader_iter = iter(loader)
for i, batch in enumerate(loader_iter):
assert loader._get_batch_size(batch) == batch_size
if i == num_batches - 1:
break
state_dict = loader.state_dict()
assert state_dict["returned_samples"] == batch_size * num_batches
assert state_dict["dataset"] == {
"curr_iter": batch_size * num_batches,
"num_workers": num_workers,
}
state_dict["returned_samples"] += 1
state_dict["dataset"]["curr_iter"] += 1
loader.load_state_dict(state_dict)
assert loader.returned_samples == batch_size * num_batches + 1
assert loader.dataset.curr_iter == batch_size * num_batches + 1
def test_state_dict_error():
loader = DataLoader([1, 2, 3])
with pytest.raises(
TypeError,
match="The dataset has no method `state_dict` that accepts `returned_samples` and `num_workers`",
):
loader.state_dict()
def test_load_state_dict_error():
loader = DataLoader([1, 2, 3])
with pytest.raises(
TypeError,
match="The dataset has no method `load_state_dict` accepting a `state_dict`",
):
loader.load_state_dict({"returned_samples": 1, "dataset": {"some_key": "some_val"}})