# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pickle
from typing import Any, NamedTuple, Sequence, get_args
from unittest.mock import Mock

import pytest
import torch
from lightning.fabric.utilities.types import _Stateful
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.utilities.combined_loader import (
    _LITERAL_SUPPORTED_MODES,
    _SUPPORTED_MODES,
    CombinedLoader,
    _MaxSize,
    _MaxSizeCycle,
    _MinSize,
    _Sequential,
)
from torch import Tensor
from torch.utils._pytree import tree_flatten
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from tests_pytorch.helpers.runif import RunIf


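# Not a test: a minimal, illustrative sketch of the API exercised below. The names and numbers here are
# arbitrary examples, not fixtures from this suite. `CombinedLoader` wraps a collection of iterables and
# yields `(batch, batch_idx, dataloader_idx)` triplets; `len()` is only defined after `iter()` was called.
def _combined_loader_usage_sketch():
    loaders = {
        "a": DataLoader(range(6), batch_size=2),  # 3 batches
        "b": DataLoader(range(10), batch_size=2),  # 5 batches
    }
    combined = CombinedLoader(loaders, "max_size_cycle")
    iter(combined)
    assert len(combined) == 5  # max_size_cycle runs for the length of the longest iterable
    for batch, batch_idx, _ in combined:
        assert list(batch) == ["a", "b"]  # each batch mirrors the structure of the input collection

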
@pytest.mark.parametrize(
    ("dataset_1", "dataset_2"),
    [
        (list(range(10)), list(range(20))),
        (range(10), range(20)),
        (torch.randn(10, 3, 2), torch.randn(20, 5, 6)),
        (TensorDataset(torch.randn(10, 3, 2)), TensorDataset(torch.randn(20, 5, 6))),
    ],
)
def test_combined_dataset(dataset_1, dataset_2):
    datasets = [DataLoader(dataset_1), DataLoader(dataset_2)]
    combined_loader = CombinedLoader(datasets, "max_size_cycle")
    assert combined_loader._dataset_length() == 20


def test_combined_dataset_no_length():
    class Foo:
        # map-style
        def __len__(self):
            return 5

    class Bar:
        # iterable style
        ...

    class Baz:
        # None length
        def __len__(self):
            pass

    cl = CombinedLoader([DataLoader(Foo()), DataLoader(Bar()), DataLoader(Baz())])
    assert cl._dataset_length() == 5

    cl = CombinedLoader(DataLoader(Bar()))
    with pytest.raises(NotImplementedError, match="All datasets are iterable-style"):
        cl._dataset_length()


def test_combined_loader_length_must_call_iter_first():
    loader = CombinedLoader([1, 2, 3])
    with pytest.raises(RuntimeError, match="Please call `iter.*` first"):
        len(loader)


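# The three tests below iterate a `CombinedLoader` built from a dict, a list, and a namedtuple and check,
# for every supported mode, the yielded `(batch, batch_idx, dataloader_idx)` triplets: "min_size" stops at
# the shortest iterable, "max_size_cycle" restarts exhausted iterables, "max_size" fills exhausted entries
# with `None`, and "sequential" consumes the iterables one after another.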
def test_combined_loader_modes_for_dict():
    """Test `CombinedLoader` modes given a mapping of iterables."""
    iterables = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }
    lengths = [len(v) for v in iterables.values()]

    # min_size with dict
    min_len = min(lengths)
    combined_loader = CombinedLoader(iterables, "min_size")
    iter(combined_loader)
    assert combined_loader._iterator is not None
    assert len(combined_loader) == min_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with dict
    max_len = max(lengths)
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    iter(combined_loader)
    assert combined_loader._iterator is not None
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # max_size with dict
    combined_loader = CombinedLoader(iterables, "max_size")
    iter(combined_loader)
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSize)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
        are_nones = [x is None for x in item.values()]
        should_be_nones = [idx >= length for length in lengths]
        assert are_nones == should_be_nones
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with dict
    sum_len = sum(lengths)
    combined_loader = CombinedLoader(iterables, "sequential")
    iter(combined_loader)
    assert combined_loader._iterator is not None
    assert len(combined_loader) == sum_len
    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(batch_idx, int)
        assert isinstance(item, Tensor)
    assert batch_idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1
    assert dataloader_idx == len(iterables) - 1


def test_combined_loader_modes_for_list():
    """Test `CombinedLoader` modes given a list of iterables."""
    iterables = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]
    lengths = [len(v) for v in iterables]

    # min_size with list
    min_len = min(lengths)
    combined_loader = CombinedLoader(iterables, "min_size")
    iter(combined_loader)
    assert len(combined_loader) == min_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with list
    max_len = max(lengths)
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    iter(combined_loader)
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # max_size with list
    combined_loader = CombinedLoader(iterables, "max_size")
    iter(combined_loader)
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSize)
        assert isinstance(item, list)
        assert len(item) == 2
        are_nones = [x is None for x in item]
        should_be_nones = [idx >= length for length in lengths]
        assert are_nones == should_be_nones

    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with list
    sum_len = sum(lengths)
    combined_loader = CombinedLoader(iterables, "sequential")
    iter(combined_loader)
    assert combined_loader._iterator is not None
    assert len(combined_loader) == sum_len
    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(batch_idx, int)
        assert isinstance(item, Tensor)
    assert batch_idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1
    assert dataloader_idx == len(iterables) - 1


def test_combined_loader_modes_for_namedtuple():
    """Test `CombinedLoader` modes given a namedtuple of iterables."""

    class IterablesNamedTuple(NamedTuple):
        a: Any
        b: Any

    iterables = IterablesNamedTuple(
        a=torch.utils.data.DataLoader(range(10), batch_size=4), b=torch.utils.data.DataLoader(range(20), batch_size=5)
    )
    lengths = [len(v) for v in iterables]

    # min_size with namedtuple
    min_len = min(lengths)
    combined_loader = CombinedLoader(iterables, "min_size")
    iter(combined_loader)
    assert len(combined_loader) == min_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, IterablesNamedTuple)
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with namedtuple
    max_len = max(lengths)
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    iter(combined_loader)
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, IterablesNamedTuple)
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # max_size with namedtuple
    combined_loader = CombinedLoader(iterables, "max_size")
    iter(combined_loader)
    assert len(combined_loader) == max_len
    for item, idx, _ in combined_loader:
        assert isinstance(combined_loader._iterator, _MaxSize)
        assert isinstance(item, IterablesNamedTuple)
        are_nones = [x is None for x in item]
        should_be_nones = [idx >= length for length in lengths]
        assert are_nones == should_be_nones
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with namedtuple
    sum_len = sum(lengths)
    combined_loader = CombinedLoader(iterables, "sequential")
    iter(combined_loader)
    assert combined_loader._iterator is not None
    assert len(combined_loader) == sum_len
    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(batch_idx, int)
        assert isinstance(item, Tensor)
    assert batch_idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1
    assert dataloader_idx == len(iterables) - 1


def test_combined_loader_raises():
    with pytest.raises(ValueError, match="Unsupported mode 'testtt'"):
        CombinedLoader([range(10)], "testtt")


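# A minimal iterable-style dataset (no `__len__`), used by the tests below to exercise the code paths
# where the dataset length is not known up front.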
class TestIterableDataset(IterableDataset):
    def __init__(self, size: int = 10):
        self.size = size

    def __iter__(self):
        self.sampler = SequentialSampler(range(self.size))
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
    """Test `CombinedLoader` given sequences of iterable-style datasets."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
            torch.utils.data.DataLoader(TestIterableDataset(20), batch_size=2),
        ]
    else:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]
    combined_loader = CombinedLoader(loaders, mode)

    has_break = False
    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        if not use_multiple_dataloaders and idx == 4:
            has_break = True
            break

    if mode == "max_size_cycle":
        assert all(combined_loader._iterator._consumed) == (not has_break)
    expected = 5
    if use_multiple_dataloaders:
        if mode in ["max_size_cycle", "max_size"]:
            expected = 10
        elif mode == "sequential":
            expected = 15
    assert idx == expected - 1


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
def test_combined_loader_simultaneous_workers(mode):
    """Test `CombinedLoader` to check how it initializes dataloader workers."""

    class TestDataLoader(DataLoader):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.workers_active = False

        def _get_iterator(self):
            self.workers_active = True
            return super()._get_iterator()

        def _shutdown_workers(self):
            self.workers_active = False
            super()._shutdown_workers()

    loaders = [
        TestDataLoader(range(10), batch_size=2, num_workers=0),
        TestDataLoader(range(20), batch_size=2, num_workers=0),
    ]
    combined_loader = CombinedLoader(loaders, mode)
    # Start the dataloader
    _ = iter(combined_loader)

    workers_active = []
    for loader in loaders:
        workers_active.append(loader.workers_active)

    # Sequential only starts the first dataloader, other modes start both
    expected = [True, False] if mode == "sequential" else [True, True]
    assert workers_active == expected


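# In "sequential" mode, `limits` caps how many batches are drawn from each iterable: `None` means no cap,
# and `0` skips that iterable entirely, as the parametrization below spells out.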
@pytest.mark.parametrize(
    ("limits", "expected"),
    [
        (None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]),
        ([1, 0], [("a", 0, 0)]),
        ([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]),
        ([1, 1], [("a", 0, 0), ("d", 0, 1)]),
    ],
)
def test_sequential_mode_limits(limits, expected):
    iterable1 = ["a", "b", "c"]
    iterable2 = ["d", "e"]
    iterator = _Sequential([iterable1, iterable2], limits)
    assert list(iterator) == expected


@pytest.mark.parametrize("iterator_cls", [_Sequential, _MinSize, _MaxSize, _MaxSizeCycle])
def test_iterator_mode_limits_raises(iterator_cls):
    with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"):
        iterator_cls([0, 1], [])


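# `CombinedLoader.flattened` exposes the leaf iterables of an arbitrarily nested collection; assigning to it
# writes new leaves back into the original nested structure and must keep the same flattened length.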
def test_combined_loader_flattened_setter():
    iterables = [[0], [[1], [[2]]]]
    combined_loader = CombinedLoader(iterables)
    with pytest.raises(ValueError, match=r"Mismatch in flattened length \(1\) and existing length \(3\)"):
        combined_loader.flattened = [2]
    assert combined_loader.flattened == [[0], [1], [2]]
    combined_loader.flattened = [[3], [2], [1]]
    assert combined_loader.iterables == [[3], [[2], [[1]]]]


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
def test_combined_loader_sequence_with_map_and_iterable(lengths):
    class MyIterableDataset(IterableDataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __iter__(self):
            self.sampler = SequentialSampler(range(self.size))
            self.iter_sampler = iter(self.sampler)
            return self

        def __next__(self):
            return next(self.iter_sampler)

    class MyMapDataset(Dataset):
        def __init__(self, size: int = 10):
            self.size = size

        def __getitem__(self, index):
            return index

        def __len__(self):
            return self.size

    x, y = lengths
    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
    seen = sum(1 for _ in dataloader)
    assert seen == max(x, y)


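# The DDP-style tests below attach a `CombinedLoader` through the trainer's data connector and check that
# `use_distributed_sampler=True` replaces every flattened sampler with a `DistributedSampler`.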
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_combined_data_loader_validation_test(use_distributed_sampler):
    """This test makes sure the distributed sampler is properly injected into the dataloaders when using
    CombinedLoader."""

    class CustomDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    class CustomSampler(RandomSampler):
        def __init__(self, data_source, name) -> None:
            super().__init__(data_source)
            self.name = name

    dataset = CustomDataset(range(10))
    combined_loader = CombinedLoader({
        "a": DataLoader(CustomDataset(range(10))),
        "b": DataLoader(dataset, sampler=CustomSampler(dataset, "custom_sampler")),
        "c": {"c": DataLoader(CustomDataset(range(10))), "d": DataLoader(CustomDataset(range(10)))},
        "d": [DataLoader(CustomDataset(range(10))), DataLoader(CustomDataset(range(10)))],
    })
    model = BoringModel()
    trainer = Trainer(use_distributed_sampler=use_distributed_sampler, strategy="ddp", accelerator="cpu", devices=2)
    trainer.strategy.connect(model)
    trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)
    trainer.fit_loop.setup_data()

    samplers_flattened = tree_flatten(combined_loader.sampler)[0]
    assert len(samplers_flattened) == 6
    if use_distributed_sampler:
        assert all(isinstance(s, DistributedSampler) for s in samplers_flattened)
    else:
        assert all(isinstance(s, (SequentialSampler, CustomSampler)) for s in samplers_flattened)

    datasets_flattened = [dl.dataset for dl in combined_loader.flattened]
    assert len(datasets_flattened) == 6
    assert all(isinstance(ds, CustomDataset) for ds in datasets_flattened)


@pytest.mark.parametrize("accelerator", ["cpu", pytest.param("gpu", marks=RunIf(min_cuda_gpus=2))])
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
def test_combined_data_loader_with_max_size_cycle_and_ddp(monkeypatch, accelerator, use_distributed_sampler):
    """This test makes sure the distributed sampler is properly injected into the dataloaders when using
    CombinedLoader with ddp and `max_size_cycle` mode."""
    trainer = Trainer(
        strategy="ddp", accelerator=accelerator, devices=2, use_distributed_sampler=use_distributed_sampler
    )
    model = BoringModel()

    combined_loader = CombinedLoader(
        {"a": DataLoader(RandomDataset(32, 8), batch_size=1), "b": DataLoader(RandomDataset(32, 8), batch_size=1)},
    )
    trainer.strategy.connect(model)
    trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)
    trainer.fit_loop.setup_data()

    assert len(combined_loader) == (4 if use_distributed_sampler else 8)

    for a_length in [6, 8, 10]:
        combined_loader = CombinedLoader(
            {
                "a": DataLoader(range(a_length), batch_size=1),
                "b": DataLoader(range(8), batch_size=1),
            },
            mode="max_size_cycle",
        )
        iter(combined_loader)
        length = max(a_length, 8)
        assert len(combined_loader) == length

        trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)

        original_process_dataloader = trainer._data_connector._prepare_dataloader

        def non_shuffle_process_dataloader(dl, shuffle, mode):
            # avoid shuffling
            return original_process_dataloader(dl, False, mode)

        monkeypatch.setattr(trainer._data_connector, "_prepare_dataloader", non_shuffle_process_dataloader)
        trainer.fit_loop.setup_data()
        monkeypatch.undo()

        assert len(combined_loader) == (length // 2 if use_distributed_sampler else length)
        if use_distributed_sampler:
            last_batch = list(combined_loader)[-1][0]
            if a_length == 6:
                assert last_batch == {"a": torch.tensor([0]), "b": torch.tensor([6])}
            elif a_length == 8:
                assert last_batch == {"a": torch.tensor([6]), "b": torch.tensor([6])}
            elif a_length == 10:
                assert last_batch == {"a": torch.tensor([8]), "b": torch.tensor([0])}

    class InfiniteDataset(IterableDataset):
        def __iter__(self):
            while True:
                yield 1

    combined_loader = CombinedLoader(
        {
            "a": DataLoader(InfiniteDataset(), batch_size=1),
            "b": DataLoader(range(8), batch_size=1),
        },
        mode="max_size_cycle",
    )
    assert len(combined_loader.iterables["b"]) == 8

    trainer._data_connector.attach_data(model, train_dataloaders=combined_loader)
    trainer.fit_loop.setup_data()

    assert len(combined_loader.iterables["b"]) == (4 if use_distributed_sampler else 8)


|
@pytest.mark.parametrize("use_distributed_sampler", [False, True])
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
def test_combined_dataloader_for_training_with_ddp(use_distributed_sampler, mode, mps_count_0):
    """When providing a CombinedLoader as the training data, it should correctly receive the distributed samplers."""
    dim = 3
    n1 = 8
    n2 = 6
    dataloader = {
        "a": DataLoader(RandomDataset(dim, n1), batch_size=1),
        "b": DataLoader(RandomDataset(dim, n2), batch_size=1),
    }
    if mode != "max_size_cycle":
        dataloader = CombinedLoader(dataloader, mode=mode)
    model = BoringModel()
    trainer = Trainer(
        strategy="ddp",
        accelerator="auto",
        devices="auto",
        use_distributed_sampler=use_distributed_sampler,
    )
    trainer.strategy.connect(model)
    trainer._data_connector.attach_data(model=model, train_dataloaders=dataloader)
    fn = _SUPPORTED_MODES[mode]["fn"]
    expected_length_before_ddp = fn([n1, n2])
    expected_length_after_ddp = (
        math.ceil(expected_length_before_ddp / trainer.num_devices)
        if use_distributed_sampler
        else expected_length_before_ddp
    )
    trainer.fit_loop.setup_data()
    assert trainer.train_dataloader is not None
    assert isinstance(trainer.fit_loop._combined_loader, CombinedLoader)
    assert trainer.fit_loop._combined_loader._mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp


def test_supported_modes():
    assert set(_SUPPORTED_MODES) == set(get_args(_LITERAL_SUPPORTED_MODES))


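# A plain `DataLoader` iterator cannot be pickled, so `_ModeIterator.__getstate__` drops such iterators
# (replacing them with `None`) to keep the `CombinedLoader` itself picklable.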
def test_combined_loader_can_be_pickled():
    dataloader = DataLoader([0, 1, 2, 3])

    # sanity check that an error would be raised. if this ever changes, `_ModeIterator.__getstate__` should be updated
    iterator = iter(dataloader)
    with pytest.raises(NotImplementedError, match="cannot be pickled"):
        pickle.dumps(iterator)

    numbers = list(range(10))
    cl = CombinedLoader([dataloader, numbers])
    iter(cl)

    iterator = cl._iterator
    assert iterator.__getstate__() == {
        "iterables": [dataloader, numbers],
        "iterators": [None, iterator.iterators[1]],
        "limits": None,
        "_idx": 0,
    }

    # no error
    pickle.dumps(cl)


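# `_state_dicts` collects state only from flattened iterables that implement the `_Stateful` protocol,
# preserving their flattening order; `_load_state_dicts` is its counterpart and validates the state count.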
def test_state_dicts():
    state1, state2, state3 = Mock(), Mock(), Mock()
    stateful1 = Mock(spec=_Stateful, state_dict=Mock(return_value=state1))
    stateful2 = Mock(spec=_Stateful, state_dict=Mock(return_value=state2))
    stateful3 = Mock(spec=_Stateful, state_dict=Mock(return_value=state3))

    cl = CombinedLoader([])
    assert cl._state_dicts() == []
    cl = CombinedLoader([range(2)])
    assert cl._state_dicts() == []
    cl = CombinedLoader([stateful1])
    assert cl._state_dicts() == [state1]
    cl = CombinedLoader([range(2), stateful1])
    assert cl._state_dicts() == [state1]
    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
    assert cl._state_dicts() == [state1, state2]
    cl = CombinedLoader({"a": [range(2), stateful1], "b": [stateful2], "c": stateful3})
    assert cl._state_dicts() == [state1, state2, state3]


def test_load_state_dicts():
    stateful1 = Mock(spec=_Stateful)
    stateful2 = Mock(spec=_Stateful)
    state1 = Mock()
    state2 = Mock()

    # 0 stateful loaders, 1 state to load
    cl = CombinedLoader([range(2), range(3)])
    with pytest.raises(RuntimeError, match="has 0 stateful loaders, but found 1 states"):
        cl._load_state_dicts([{"state": 0}])

    # 1 stateful loader, 0 states to load
    cl = CombinedLoader([stateful1, range(3)])
    cl._load_state_dicts([])
    stateful1.load_state_dict.assert_not_called()

    # 1 stateful loader, 1 state to load
    cl = CombinedLoader([range(2), stateful1, range(3)])
    cl._load_state_dicts([state1])
    stateful1.load_state_dict.assert_called_with(state1)
    stateful1.reset_mock()

    # 1 stateful loader, 2 states to load
    cl = CombinedLoader([range(2), stateful1, range(3)])
    with pytest.raises(RuntimeError, match="has 1 stateful loaders, but found 2 states"):
        cl._load_state_dicts([state1, state2])

    # 2 stateful loaders, 2 states to load
    cl = CombinedLoader([range(2), stateful1, range(3), stateful2])
    cl._load_state_dicts([state1, state2])
    stateful1.load_state_dict.assert_called_with(state1)
    stateful2.load_state_dict.assert_called_with(state2)