# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
from typing import Sequence
from unittest import mock

import pytest
import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.trainer.supporters import (
    _nested_calc_num_data,
    CombinedDataset,
    CombinedLoader,
    CombinedLoaderIterator,
    CycleIterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.runif import RunIf


def test_cycle_iterator():
    """Test the cycling function of `CycleIterator`."""
iterator = CycleIterator(range(100), 1000)
assert len(iterator) == 1000
for idx, item in enumerate(iterator):
assert item < 100
    assert idx == len(iterator) - 1


def test_none_length_cycle_iterator():
    """Test the infinite cycling function of `CycleIterator`."""
iterator = CycleIterator(range(100))
assert iterator.__len__() == float("inf")
# test infinite loop
for idx, item in enumerate(iterator):
if idx == 1000:
break
            assert item == 0


@pytest.mark.parametrize(
["dataset_1", "dataset_2"],
[
([list(range(10)), list(range(20))]),
([range(10), range(20)]),
([torch.randn(10, 3, 2), torch.randn(20, 5, 6)]),
([TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))]),
],
)
def test_combined_dataset(dataset_1, dataset_2):
"""Verify the length of the CombinedDataset."""
datasets = [dataset_1, dataset_2]
combined_dataset = CombinedDataset(datasets)
assert combined_dataset.max_len == 20
    assert combined_dataset.min_len == len(combined_dataset) == 10


def test_combined_dataset_length_mode_error():
dset = CombinedDataset([range(10)])
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        dset._calc_num_data([range(10)], "test")


def test_combined_loader_iterator_dict_min_size():
"""Test `CombinedLoaderIterator` given mapping loaders."""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
combined_iter = CombinedLoaderIterator(loaders)
for idx, item in enumerate(combined_iter):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1


def test_combined_loader_init_mode_error():
    """Test the `MisconfigurationException` raised when constructing `CombinedLoader` with an invalid mode."""
with pytest.raises(MisconfigurationException, match="Invalid Mode"):
        CombinedLoader([range(10)], "testtt")


def test_combined_loader_loader_type_error():
    """Test the `TypeError` raised when wrapping loaders of an unsupported type."""
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader(None, "max_size_cycle")


def test_combined_loader_calc_length_mode_error():
    """Test the `TypeError` raised when calculating the number of batches with unsupported data."""
with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"):
        CombinedLoader._calc_num_batches(None)


def test_combined_loader_dict_min_size():
"""Test `CombinedLoader` of mode 'min_size' given mapping loaders."""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
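    # "a" yields 3 batches (10 items, batch_size=4) and "b" yields 4 (20/5);
    # "min_size" stops once the shorter loader is exhausted.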
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders.values())
for idx, item in enumerate(combined_loader):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
    assert idx == len(combined_loader) - 1


def test_combined_loader_dict_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
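    # "max_size_cycle" runs for the longer loader's 4 batches, re-cycling "a" as needed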
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders.values())
for idx, item in enumerate(combined_loader):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
    assert idx == len(combined_loader) - 1


def test_combined_loader_sequence_min_size():
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
    assert idx == len(combined_loader) - 1


class TestIterableDataset(IterableDataset):
    def __init__(self, size: int = 10):
        self.size = size

    def __iter__(self):
        self.sampler = SequentialSampler(range(self.size))
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
if use_multiple_dataloaders:
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
]
else:
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
]
combined_loader = CombinedLoader(loaders, mode)
has_break = False
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
        assert len(item) == (2 if use_multiple_dataloaders else 1)
if not use_multiple_dataloaders and idx == 4:
has_break = True
break
if mode == "max_size_cycle":
assert combined_loader.loaders[0].state.done == (not has_break)
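
    # A single loader yields 5 batches (10 items, batch_size=2); with two loaders the longer
    # one yields 10, so "max_size_cycle" runs 10 steps while "min_size" stops at 5.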
expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    class MyIterableDataset(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class MyMapDataset(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    x, y = lengths
loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
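    # one iterable and one map-style loader; "max_size_cycle" should cycle the shorter one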
dataloader = CombinedLoader(loaders, mode="max_size_cycle")
counter = 0
for _ in dataloader:
counter += 1
    assert counter == max(x, y)


def test_combined_loader_sequence_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
    assert idx == len(combined_loader) - 1


@pytest.mark.parametrize(
["input_data", "compute_func", "expected_length"],
[
([*range(10), list(range(1, 20))], min, 0),
([*range(10), list(range(1, 20))], max, 19),
([*range(10), {str(i): i for i in range(1, 20)}], min, 0),
([*range(10), {str(i): i for i in range(1, 20)}], max, 19),
({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, min, 0),
({**{str(i): i for i in range(10)}, "nested": {str(i): i for i in range(1, 20)}}, max, 19),
({**{str(i): i for i in range(10)}, "nested": list(range(20))}, min, 0),
({**{str(i): i for i in range(10)}, "nested": list(range(20))}, max, 19),
],
)
def test_nested_calc_num_data(input_data, compute_func, expected_length):
calculated_length = _nested_calc_num_data(input_data, compute_func)
    assert calculated_length == expected_length


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_validation_test(mps_count_0, cuda_count_2, replace_sampler_ddp):
    """This test makes sure the distributed sampler is properly injected into the dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    dataset = CustomDataset(range(10))
dataloader = CombinedLoader(
{
"a": DataLoader(CustomDataset(range(10))),
"b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")),
"c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
"d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
}
)
trainer = Trainer(replace_sampler_ddp=replace_sampler_ddp, strategy="ddp", accelerator="gpu", devices=2)
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)
    count = 0

    def _assert_distributed_sampler(v):
        nonlocal count
        count += 1
        if replace_sampler_ddp:
            assert isinstance(v, DistributedSampler)
        else:
            assert isinstance(v, (SequentialSampler, CustomSampler))
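
    # The nested structure above contains 6 DataLoaders in total ("a", "b", two under "c",
    # and two under "d"), so the sampler check should run exactly 6 times.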
apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
assert count == 6

    def _assert_dataset(loader):
d = loader.dataset
assert isinstance(d, CustomDataset)

    apply_to_collection(dataloader.loaders, DataLoader, _assert_dataset)


@pytest.mark.parametrize("accelerator", ["cpu", pytest.param("gpu", marks=RunIf(min_cuda_gpus=2))])
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_sampler_ddp):
"""This test makes sure distributed sampler has been properly injected in dataloaders when using CombinedLoader
with ddp and `max_size_cycle` mode."""
trainer = Trainer(strategy="ddp", accelerator=accelerator, devices=2, replace_sampler_ddp=replace_sampler_ddp)
dataloader = CombinedLoader(
{"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
)
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader) == (4 if replace_sampler_ddp else 8)
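
    # Try an "a" loader that is shorter than, equal to, and longer than the fixed "b" loader.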
for a_length in [6, 8, 10]:
dataloader = CombinedLoader(
{
"a": DataLoader(range(a_length), batch_size=1),
"b": DataLoader(range(8), batch_size=1),
},
mode="max_size_cycle",
)
length = max(a_length, 8)
assert len(dataloader) == length
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
        assert len(dataloader) == (length // 2 if replace_sampler_ddp else length)
if replace_sampler_ddp:
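            # with 2 ranks and shuffle=False, rank 0 sees the even indices of each loader;
            # under "max_size_cycle" the shorter loader wraps back to index 0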
last_batch = list(dataloader)[-1]
if a_length == 6:
assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
elif a_length == 8:
assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
elif a_length == 10:
assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

dataloader = CombinedLoader(
{
"a": DataLoader(InfiniteDataset(), batch_size=1),
"b": DataLoader(range(8), batch_size=1),
},
mode="max_size_cycle",
)
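    # the infinite "a" loader keeps the combined length unbounded, while the finite "b"
    # loader is still halved per process once its sampler is replaced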
assert get_len(dataloader) == float("inf")
assert len(dataloader.loaders["b"].loader) == 8
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
assert len(dataloader.loaders["b"].loader) == 4 if replace_sampler_ddp else 8
assert get_len(dataloader) == float("inf")
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
@pytest.mark.parametrize("is_min_size_mode", [False, True])
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_combined_dataloader_for_training_with_ddp(
replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
"""When providing a CombinedLoader as the training data, it should be correctly receive the distributed
samplers."""
mode = "min_size" if is_min_size_mode else "max_size_cycle"
dim = 3
n1 = 8
n2 = 6
dataloader = {
"a": DataLoader(RandomDataset(dim, n1), batch_size=1),
"b": DataLoader(RandomDataset(dim, n2), batch_size=1),
}
if use_combined_loader:
dataloader = CombinedLoader(dataloader, mode=mode)
model = BoringModel()
trainer = Trainer(
strategy="ddp",
accelerator="auto",
devices="auto",
replace_sampler_ddp=replace_sampler_ddp,
multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
)
trainer._data_connector.attach_data(
model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
)
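
    # "min_size" stops at the shorter loader (6 batches); "max_size_cycle" runs to the longer (8).
    # When the sampler is replaced, each device sees ceil(length / num_devices) batches.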
expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
expected_length_after_ddp = (
math.ceil(expected_length_before_ddp / trainer.num_devices)
if replace_sampler_ddp
else expected_length_before_ddp
)
trainer.reset_train_dataloader(model=model)
assert trainer.train_dataloader is not None
assert isinstance(trainer.train_dataloader, CombinedLoader)
assert trainer.train_dataloader.mode == mode
assert trainer.num_training_batches == expected_length_after_ddp