610 lines
18 KiB
Python
610 lines
18 KiB
Python
import math
|
|
import sys
|
|
from collections import Counter
|
|
from functools import partial
|
|
from typing import Any, Dict
|
|
|
|
import lightning
|
|
import pytest
|
|
import torch
|
|
from lightning.data.datasets.iterable import (
|
|
DataLoader,
|
|
LightningIterableDataset,
|
|
_Chunk,
|
|
_Stateful,
|
|
_StatefulIterableDataset,
|
|
)
|
|
|
|
|
|
class Foo1:
    """Stub that satisfies ``_Stateful`` with a ``state_dict`` taking ``returned_samples``."""

    def state_dict(self, returned_samples: int) -> Dict[str, Any]:
        """Return nothing; only the presence of the method matters here."""

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Accept and ignore ``state_dict``."""
|
|
|
|
|
|
class Foo2:
    """Stub that satisfies ``_Stateful`` with a zero-argument ``state_dict``."""

    def state_dict(self) -> Dict[str, Any]:
        """Return nothing; only the presence of the method matters here."""

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Accept and ignore ``state_dict``."""
|
|
|
|
|
|
class Bar1:
    """Class with neither ``state_dict`` nor ``load_state_dict``; must not count as ``_Stateful``."""
|
|
|
|
|
|
class Bar2:
    """Class with only ``state_dict``; lacking ``load_state_dict`` it must not count as ``_Stateful``."""

    def state_dict(self) -> Dict[str, Any]:
        """Return nothing; the missing counterpart method is what this stub exercises."""
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("klass", "fulfilled"),
    [
        pytest.param(Foo1, True),
        pytest.param(Foo2, True),
        pytest.param(Bar1, False),
        pytest.param(Bar2, False),
    ],
)
def test_serializable(klass, fulfilled):
    """``_Stateful`` must accept exactly the classes defining both state methods.

    ``Foo1``/``Foo2`` define ``state_dict`` and ``load_state_dict``; ``Bar1`` defines neither
    and ``Bar2`` only ``state_dict``.
    """
    # isinstance returns a bool, so compare by identity against the expected flag
    assert isinstance(klass(), _Stateful) is fulfilled
|
|
|
|
|
|
class DummyIterableDataset(_StatefulIterableDataset):
    """Iterable dataset that yields ``0`` until ``index`` reaches ``length``.

    The instance is its own iterator, and ``index`` is never reset by ``__iter__``, so a
    new iteration continues where the previous one stopped.
    """

    def __init__(self, length: int):
        super().__init__()
        self.length = length
        # number of samples produced so far
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.index < self.length:
            self.index += 1
            return 0
        raise StopIteration
|
|
|
|
|
|
class WrongDummySerializableIterableDataset1(DummyIterableDataset):
    """Implements only ``state_dict``; ``load_state_dict`` is intentionally left unimplemented."""

    def state_dict(self):
        state = {"length": self.length}
        state["index"] = self.index
        return state
|
|
|
|
|
|
class WrongDummySerializableIterableDataset2(DummyIterableDataset):
    """Implements only ``load_state_dict``; ``state_dict`` is intentionally left unimplemented."""

    def load_state_dict(self, state_dict):
        # keys are popped (mutating the caller's dict) and applied in this order
        for attribute in ("length", "index"):
            setattr(self, attribute, state_dict.pop(attribute))
|
|
|
|
|
|
class WorkingDummySerializableIterableDataset(
    WrongDummySerializableIterableDataset1,
    WrongDummySerializableIterableDataset2,
):
    """Concrete dataset: ``state_dict`` and ``load_state_dict`` come from its two bases."""
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("klass", "missing_method"),
    [
        pytest.param(WrongDummySerializableIterableDataset1, "load_state_dict"),
        pytest.param(WrongDummySerializableIterableDataset2, "state_dict"),
    ],
)
def test_required_abstract_methods_serializable_dataset(klass, missing_method):
    """Instantiating a stateful dataset that lacks one of the state methods must fail.

    The pattern accepts both ``TypeError`` wordings: before Python 3.12
    ("... with abstract method X") and from 3.12 on
    ("... without an implementation for abstract method 'X'").
    """
    pattern = (
        f"Can't instantiate abstract class {klass.__name__} "
        "with(out an implementation for)? abstract method.*"
        # the name is preceded by a space (< 3.12) or a quote (>= 3.12), which also keeps
        # "state_dict" from matching inside "load_state_dict"
        f"[ ']{missing_method}"
    )
    with pytest.raises(TypeError, match=pattern):
        klass(10)
|
|
|
|
|
|
def test_serialization_iterable_dataset():
    """``state_dict`` must reflect the iteration progress before and after every ``next``."""
    dataset = WorkingDummySerializableIterableDataset(10)
    iterator = iter(dataset)

    # the dataset acts as its own iterator
    assert iterator is dataset

    for step in range(10):
        assert dataset.state_dict() == {"length": 10, "index": step}
        next(iterator)
        assert dataset.state_dict() == {"length": 10, "index": step + 1}
|
|
|
|
|
|
def test_iteration_serializable_iterable_dataset():
    """A full pass over a fresh dataset yields exactly ``length`` samples."""
    dataset = WorkingDummySerializableIterableDataset(10)
    num_samples = sum(1 for _ in dataset)
    assert num_samples == 10
|
|
|
|
|
|
def test_resume_iterable_dataset():
    """Loading a saved state into a fresh dataset resumes iteration at the saved position."""
    source = WorkingDummySerializableIterableDataset(10)
    source_iter = iter(source)

    # consume half of the samples
    for _ in range(5):
        next(source_iter)
    assert source.state_dict() == {"length": 10, "index": 5}

    resumed = WorkingDummySerializableIterableDataset(12)
    resumed.load_state_dict(source.state_dict())

    # the loaded state overrides the constructor argument
    assert resumed.length == 10
    assert resumed.index == 5

    # only the remaining 5 samples are produced
    count = sum(1 for _ in resumed)
    assert count == 5

    # growing the length lets a new pass continue from where the last one stopped
    resumed.length = 12
    count += sum(1 for _ in resumed)
    assert count == 7
    assert resumed.state_dict() == {"length": 12, "index": 12}
|
|
|
|
|
|
class WrongChunkedDataset1(LightningIterableDataset):
    """Implements only ``load_chunk``; ``load_sample_from_chunk`` is intentionally missing."""

    def load_chunk(self, curr_chunk: int):
        # a "loaded" chunk is the list of ``(chunk, position)`` pairs it contains
        return [(curr_chunk, position) for position in range(self._chunk_size)]
|
|
|
|
|
|
class WrongChunkedDataset2(LightningIterableDataset):
    """Implements only ``load_sample_from_chunk``; ``load_chunk`` is intentionally missing."""

    def load_sample_from_chunk(self, curr_chunk, curr_index):
        # assumes the loaded chunk is indexable (e.g. the list built by ``load_chunk``)
        sample = curr_chunk[curr_index]
        return sample
|
|
|
|
|
|
class WorkingChunkedDataset(WrongChunkedDataset1, WrongChunkedDataset2):
    """Concrete chunked dataset: ``load_chunk`` and ``load_sample_from_chunk`` come from its bases."""
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("klass", "missing_method"),
    [
        pytest.param(WrongChunkedDataset1, "load_sample_from_chunk"),
        pytest.param(WrongChunkedDataset2, "load_chunk"),
    ],
)
def test_required_abstract_methods_chunked_dataset(klass, missing_method):
    """Instantiating a chunked dataset that lacks one of the loader hooks must fail.

    The pattern accepts both ``TypeError`` wordings: before Python 3.12
    ("... with abstract method X") and from 3.12 on
    ("... without an implementation for abstract method 'X'").
    """
    pattern = (
        f"Can't instantiate abstract class {klass.__name__} "
        "with(out an implementation for)? abstract method.*"
        # the name is preceded by a space (< 3.12) or a quote (>= 3.12)
        f"[ ']{missing_method}"
    )
    with pytest.raises(TypeError, match=pattern):
        klass([10], 10)
|
|
|
|
|
|
def test_chunked_dataset_iteration():
    """Without shuffling, chunks are visited in order and each yields its samples in order."""
    dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False)

    num_samples = 0
    for i, item in enumerate(dset):
        # with chunk_size == 2, sample ``i`` is element ``i % 2`` of chunk ``i // 2``
        assert item[0] == i // 2
        assert item[1] == i % 2
        num_samples += 1

    # 5 chunks x 2 samples; counted explicitly so that an empty dataset fails this
    # assertion instead of raising NameError on an unbound loop variable
    assert num_samples == 10
|
|
|
|
|
|
@pytest.mark.parametrize("lazy_shuffle", [False, True])
def test_chunk_dataset_iteration_shuffle(lazy_shuffle):
    """Shuffling keeps every sample exactly once but changes the sequential order."""
    dset = WorkingChunkedDataset(
        list(range(5)),
        chunk_size=2,
        shuffle=True,
        seed=12345,
        wrap=False,
        lazy_shuffle=lazy_shuffle,
    )

    pairs = [(item, key) for item, key in dset]
    series = [item for item, _ in pairs]
    series_keys = [key for _, key in pairs]

    # every chunk id shows up once per sample in its chunk (chunk_size == 2)
    assert all(count == 2 for count in Counter(series).values())

    # with shuffling the order can't match the unshuffled one
    ordered_series = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    ordered_keys = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
    assert series != ordered_series
    assert series_keys != ordered_keys
|
|
|
|
|
|
def test_chunked_dataset_wrap():
    """With ``wrap=True`` the iterator keeps producing samples beyond the end of the data."""
    dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=True, seed=12345, wrap=True)
    dset_iter = iter(dset)

    # one pass holds 10 samples, so 21 draws wrap around twice; a StopIteration
    # raised by ``next`` would error the test
    for _ in range(21):
        next(dset_iter)
|
|
|
|
|
|
def test_chunked_dataset_resume_and_reset():
    """Every new iterator restarts from the configured start chunk / start sample.

    ``_start_index_chunk`` and ``_start_index_sample`` are set directly here to emulate
    what loading a state dict would do.
    """
    dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=False, wrap=False)

    for i, item in enumerate(dset):
        assert item[0] == 0
        assert item[1] == i
        if i == 1:
            break

    # Every iterator starts from scratch
    for i, item in enumerate(dset):
        assert item[0] == 0
        assert item[1] == i
        if i == 1:
            break

    # this would be set when we load from state dict
    dset._start_index_sample = 1
    # expected: (0, 1) then (1, 0)
    for i, item in enumerate(dset):
        assert item[0] == i
        assert item[1] == (i + 1) % 2
        if i == 1:
            break

    # assignment (not a comparison): also shift the start chunk by one
    dset._start_index_chunk = 1
    # expected: the previous sequence shifted by one chunk, i.e. (1, 1) then (2, 0)
    for i, item in enumerate(dset):
        assert item[0] == i + 1
        assert item[1] == (i + 1) % 2
        if i == 1:
            break
|
|
|
|
|
|
@pytest.mark.parametrize("shuffle", [False, True])
def test_chunked_dataset_serialization(shuffle):
    """Round-trip the chunked dataset's state dict and check where iteration resumes.

    ``state_dict`` is called as ``state_dict(returned_samples, num_workers)``. The assertions
    encode that resuming from a position inside a chunk skips ahead to the beginning of the
    next chunk, while resuming from a chunk boundary reproduces the same samples.
    """
    dset = WorkingChunkedDataset(list(range(5)), chunk_size=2, shuffle=shuffle, wrap=False)

    # zero returned samples -> state points at the first sample of the first chunk
    assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}

    dset_iter = iter(dset)
    assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}

    dset.load_state_dict(dset.state_dict(0, 0))
    assert dset.state_dict(0, 0) == {"current_chunk": 0, "current_sample_in_chunk": 0}

    dset_iter = iter(dset)

    # throw away first few batches to alter internal state
    for i in range(3):
        next(dset_iter)

    # 3 returned samples = one full chunk plus one sample of the next (chunk_size == 2)
    curr_state = dset.state_dict(3, 0)

    original = [next(dset_iter) for _ in range(5)]

    dset.load_state_dict(curr_state)
    dset_iter = iter(dset)
    after_loading = [next(dset_iter) for _ in range(5)]

    # These differ because loading from a position that is not at a chunk boundary always skips
    # ahead to the beginning of the next chunk: the resumed stream lacks the leftover sample.
    assert original != after_loading
    assert original[1:] == after_loading[:-1]

    # curr_state already points at the beginning of a chunk. Consume that whole chunk (2 samples)
    # so the next checkpoint differs from curr_state instead of resuming twice from the same one.
    # NOTE(review): no new iter() call follows load_state_dict here or below; the assertions
    # presume load_state_dict repositions the live iterator - confirm against the implementation.
    dset.load_state_dict(curr_state)
    _ = [next(dset_iter) for _ in range(2)]

    # 6 returned samples = exactly three full chunks
    new_curr_state = dset.state_dict(6, 0)

    new_original = [next(dset_iter) for _ in range(3)]

    dset.load_state_dict(new_curr_state)
    new_after_loading = [next(dset_iter) for _ in range(3)]

    # this is equal since we exactly stopped at beginning of new chunk
    assert new_original == new_after_loading
|
|
|
|
|
|
class ChunkedTestDatasetDistributed(WorkingChunkedDataset):
    """Chunked dataset that validates its local shard right after sharding is applied.

    The caller must set ``expected_num_chunks`` and ``expected_step_width`` on the instance
    before iterating.
    """

    def _apply_sharding(self):
        super()._apply_sharding()

        local_chunks = self._local_chunks
        assert len(local_chunks) == self.expected_num_chunks

        # the data of consecutive local chunks must be ``expected_step_width`` apart
        for previous, current in zip(local_chunks, local_chunks[1:]):
            assert current._chunk_data - previous._chunk_data == self.expected_step_width
|
|
|
|
|
|
def sharding_test(fabric: lightning.Fabric, num_workers):
    """Per-process body of ``test_sharding``: validate the shard layout and batch count.

    Args:
        fabric: launched Fabric instance providing ``world_size`` and ``barrier``.
        num_workers: number of DataLoader worker processes (0 = load in the main process).
    """
    num_chunks = 50
    chunk_size = 2
    dset = ChunkedTestDatasetDistributed(list(range(num_chunks)), chunk_size, shuffle=False, wrap=False)

    # num_workers = 0 still has a single worker (the main process)
    num_shards = max(1, num_workers) * fabric.world_size

    expected_num_chunks = num_chunks // num_shards
    dset.expected_num_chunks = expected_num_chunks
    # each shard takes every ``num_shards``-th chunk
    dset.expected_step_width = num_shards

    num_samples_per_rank = max(num_workers, 1) * chunk_size * expected_num_chunks
    # default batch_size of 1, so batches == samples
    loader = torch.utils.data.DataLoader(dset, num_workers=num_workers)

    # counted explicitly so an empty loader fails the assertion rather than raising NameError
    num_batches = 0
    for _ in loader:
        fabric.barrier()
        num_batches += 1

    assert num_batches == num_samples_per_rank
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("num_workers", "world_size"),
    [
        pytest.param(
            num_workers,
            world_size,
            # only the single-process configuration runs on every platform
            marks=()
            if (num_workers, world_size) == (0, 1)
            else pytest.mark.skipif(
                sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
            ),
        )
        for num_workers in (0, 1, 2)
        for world_size in (1, 2)
    ],
)
def test_sharding(num_workers, world_size):
    """Run ``sharding_test`` on ``world_size`` CPU processes launched with ``ddp_spawn``."""
    fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn")
    fabric.launch(partial(sharding_test, num_workers=num_workers))
|
|
|
|
|
|
def sharding_resume_test(fabric: lightning.Fabric, num_workers):
    """Per-process body of ``test_chunked_dataset_sharded_state_dict_resume``.

    For several counts of already-returned samples, checks the produced state dict and that
    resuming from it makes every worker on every rank continue with the expected chunks.

    Args:
        fabric: launched Fabric instance providing ``world_size``, ``global_rank``, ``barrier``.
        num_workers: number of DataLoader worker processes (0 = load in the main process).
    """
    chunk_size = 2
    dset = WorkingChunkedDataset(list(range(100)), chunk_size, shuffle=False, wrap=False)
    # num_workers = 0 still means a single worker (the main process)
    num_effective_workers = max(1, num_workers)
    num_shards = num_effective_workers * fabric.world_size

    for returned_samples in [23, 37, 10, 20]:
        # resuming rounds up to the start of the next full round of chunks over all shards
        curr_index = math.ceil(returned_samples / num_shards / chunk_size) * num_shards * chunk_size
        next_chunk = math.ceil(curr_index / chunk_size)

        curr_state = dset.state_dict(returned_samples, num_workers=num_workers)
        assert curr_state == {"current_chunk": next_chunk, "current_sample_in_chunk": 0}

        dset.load_state_dict(curr_state)
        loader = torch.utils.data.DataLoader(dset, num_workers=num_workers, shuffle=False)

        # Starting chunk of each worker: ``next_chunk`` plus the rank's base offset
        # (``global_rank * num_effective_workers``) plus the worker's own offset in the rank.
        curr_worker_chunk = {
            worker: next_chunk + fabric.global_rank * num_effective_workers + worker
            for worker in range(num_effective_workers)
        }
        curr_worker_chunk_elem = {worker: 0 for worker in range(num_effective_workers)}

        for batch_idx, batch in enumerate(loader):
            # the test assumes the DataLoader yields batches from its workers round-robin
            curr_worker = batch_idx % num_effective_workers
            assert batch[0] == curr_worker_chunk[curr_worker]
            assert batch[1] == curr_worker_chunk_elem[curr_worker]

            curr_worker_chunk_elem[curr_worker] += 1

            if curr_worker_chunk_elem[curr_worker] == chunk_size:
                # chunk exhausted: the worker's next chunk lies ``num_shards`` further on
                curr_worker_chunk[curr_worker] += num_shards
                curr_worker_chunk_elem[curr_worker] = 0
            fabric.barrier()
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("num_workers", "world_size"),
    [
        pytest.param(
            num_workers,
            world_size,
            # only the single-process configuration runs on every platform
            marks=()
            if (num_workers, world_size) == (0, 1)
            else pytest.mark.skipif(
                sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
            ),
        )
        for num_workers in (0, 1, 2)
        for world_size in (1, 2)
    ],
)
def test_chunked_dataset_sharded_state_dict_resume(num_workers, world_size):
    """Run ``sharding_resume_test`` on ``world_size`` CPU processes launched with ``ddp_spawn``."""
    fabric = lightning.Fabric(accelerator="cpu", devices=world_size, strategy="ddp_spawn")
    fabric.launch(partial(sharding_resume_test, num_workers=num_workers))
|
|
|
|
|
|
@pytest.mark.parametrize("chunk_size", [20, 30, 40])
@pytest.mark.parametrize("shuffle", [False, True])
@pytest.mark.parametrize("shuffle_seed", [None, 123])
@pytest.mark.parametrize("delayed_start", [False, True])
def test_chunk(chunk_size, shuffle, shuffle_seed, delayed_start):
    """Exercise ``_Chunk`` iteration, an optional delayed start, and (seeded) shuffling."""
    data = list(range(chunk_size))
    # with a delayed start, iteration begins at index ``chunk_size - 10``
    delayed_start_index = int(delayed_start) * (chunk_size - 10)
    chunk = _Chunk(data, chunk_size=chunk_size, start_index=delayed_start_index)

    # before shuffling the permutation is the identity
    linear_permutation = tuple(range(chunk_size))
    assert chunk.index_permutations == linear_permutation

    for i, index in enumerate(chunk):
        assert index == delayed_start_index + i

    assert chunk.chunk_size == chunk_size

    if shuffle:
        # ``is not None`` so that a seed of 0 is still treated as a seed
        has_seed = shuffle_seed is not None
        generator = torch.Generator().manual_seed(shuffle_seed) if has_seed else None

        chunk = chunk.shuffle(generator=generator)

        old_permutation = chunk.index_permutations
        assert old_permutation != linear_permutation

        # iteration still honors the delayed start, now over the shuffled permutation
        new_perm = tuple(index for index in chunk)
        expected = tuple(old_permutation[k] for k in range(delayed_start_index, chunk_size))
        assert new_perm == expected
        assert len(new_perm) == chunk_size - delayed_start_index

        if has_seed:
            # re-seeding the same generator must reproduce the same permutation
            chunk = chunk.shuffle(generator=generator.manual_seed(shuffle_seed))
            assert chunk.index_permutations == old_permutation

    assert chunk.chunk_size == chunk_size
|
|
|
|
|
|
class MyDataset(_StatefulIterableDataset):
    """Stateful iterable over ``range(length)`` that resumes from ``curr_iter``."""

    def __init__(self, length):
        # initialize the base class, as DummyIterableDataset does
        super().__init__()
        self.length = length
        self.samples = list(range(length))
        # number of samples yielded so far; a new iteration starts from here
        self.curr_iter = 0

    def __iter__(self):
        for sample in self.samples[self.curr_iter :]:
            yield sample
            self.curr_iter += 1

    def state_dict(self, returned_samples, num_workers):
        # echo back the values passed in so the test can inspect what the loader reports
        return {"curr_iter": returned_samples, "num_workers": num_workers}

    def load_state_dict(self, state_dict):
        # ``pop`` removes the key from the caller's dict
        self.curr_iter = state_dict.pop("curr_iter")
|
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [1, 2, 3])
@pytest.mark.parametrize(
    "num_workers",
    [
        pytest.param(0),
        pytest.param(
            1,
            marks=pytest.mark.skipif(
                sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
            ),
        ),
        pytest.param(
            2,
            marks=pytest.mark.skipif(
                sys.platform != "linux", reason="multiprocessing on other platforms takes forever"
            ),
        ),
    ],
)
@pytest.mark.parametrize("prefetch_factor", [1, 2, 3])
@pytest.mark.parametrize("length", [100, 101])
@pytest.mark.parametrize("num_batches", [1, 2, 7])
def test_resumable_loader(batch_size, num_workers, prefetch_factor, length, num_batches):
    """The loader's state dict tracks returned samples and is applied by ``load_state_dict``."""
    dset = MyDataset(length)
    loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    # ``prefetch_factor`` is only passed when worker processes are used
    loader_kwargs["prefetch_factor"] = prefetch_factor if num_workers > 0 else None
    loader = DataLoader(dset, **loader_kwargs)

    for batch_idx, batch in enumerate(iter(loader)):
        assert loader._get_batch_size(batch) == batch_size
        if batch_idx + 1 == num_batches:
            break

    expected_samples = batch_size * num_batches
    state_dict = loader.state_dict()
    assert state_dict["returned_samples"] == expected_samples
    assert state_dict["dataset"] == {"curr_iter": expected_samples, "num_workers": num_workers}

    # tamper with the saved state to prove that loading actually applies it
    state_dict["returned_samples"] += 1
    state_dict["dataset"]["curr_iter"] += 1
    loader.load_state_dict(state_dict)
    assert loader.returned_samples == expected_samples + 1
    assert loader.dataset.curr_iter == expected_samples + 1
|
|
|
|
|
|
def test_state_dict_error():
    """``state_dict`` must fail clearly when the dataset has no compatible ``state_dict``."""
    loader = DataLoader([1, 2, 3])
    expected_message = "The dataset has no method `state_dict` that accepts `returned_samples` and `num_workers`"
    with pytest.raises(TypeError, match=expected_message):
        loader.state_dict()
|
|
|
|
|
|
def test_load_state_dict_error():
    """``load_state_dict`` must fail clearly when the dataset cannot accept a state dict."""
    loader = DataLoader([1, 2, 3])
    saved_state = {"returned_samples": 1, "dataset": {"some_key": "some_val"}}
    expected_message = "The dataset has no method `load_state_dict` accepting a `state_dict`"
    with pytest.raises(TypeError, match=expected_message):
        loader.load_state_dict(saved_state)
|