lightning/tests/tests_data/cache/test_sampler.py

113 lines
3.8 KiB
Python

from unittest import mock
import pytest
from lightning import seed_everything
from lightning.data.cache.sampler import CacheBatchSampler
@pytest.mark.parametrize(
"params",
[
(
21,
1,
[[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]],
[[7, 0, 0], [1, 1, 1], [5, 5, 5], [0, 4, 4], [8, 3, 3], [2, 2, 2], [4], [3], [6]],
),
(
11,
1,
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]],
[[1, 1, 1], [3, 3], [0, 0, 0], [2, 2, 2]],
),
(8, 1, [[0, 1], [2, 3], [4, 5, 6], [], [], [7]], [[1, 1, 2], [3], [0, 0], [2, 2]]),
(4, 1, [[0], [1], [2, 3]], [[0], [1], [2, 2]]),
(
9,
1,
[[0, 1, 2], [3, 4, 5], [6, 7, 8]],
[[0, 0, 0], [1, 1, 1], [2, 2, 2]],
),
(
19,
1,
[[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]],
[[0, 0, 0], [1, 1, 1], [5, 5, 5], [2, 2, 2], [4, 4, 4], [3, 3, 3], [6]],
),
(19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[0, 0, 0], [5, 5, 5], [4, 4, 4], [6]]),
],
)
def test_cache_batch_sampler(params):
seed_everything(42)
cache = mock.MagicMock()
cache.filled = False
if params[1] > 1:
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
assert batches == params[2], batches
batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]
else:
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches = []
for batch in batch_sampler:
batches.append(batch)
assert batches == params[2], batches
chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)]
cache.filled = True
cache.get_chunk_interval.return_value = chunks_interval
seed_everything(42)
batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache)
batches_1 = []
for batch in batch_sampler:
batches_1.append(batch)
def validate_batch(data, check_values):
if params[1] == 1:
assert all(b[0].chunk_indexes is not None for b in data[:3])
assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3])
assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:])
if check_values:
assert [[x.chunk_index for x in d] for d in data] == params[3]
else:
assert all(b[0].chunk_indexes is not None for b in data[:3])
assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3])
assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:])
if check_values:
assert [[x.chunk_index for x in d] for d in data] == params[3]
validate_batch(batches_1, True)
batches_2 = []
for batch in batch_sampler:
batches_2.append(batch)
validate_batch(batches_2, False)
if params[1] == 1:
assert batches_1 != batches_2
def test_batch_sampler_imagenet():
"""Validate the Imagenet dataset is valid."""
dataset_size = 1281167
world_size = 1
rank = 0
num_workers = 32
batch_size = 8
cache = mock.MagicMock()
cache.filled = False
CacheBatchSampler(dataset_size, world_size, rank, num_workers, batch_size, False, True, cache)