"sequential" mode for `CombinedLoader` (#16743)

Carlos Mocholí 2023-02-14 06:51:48 +01:00 committed by GitHub
parent a342410e25
commit 6c037a479f
11 changed files with 209 additions and 156 deletions

View File

@ -81,7 +81,7 @@ class Top1:
class ProductionReadyModel(LitModule, ServableModule):
def configure_payload(self):
# 1: Access the train dataloader and load a single sample.
image, _ = self.trainer.train_dataloader.loaders.dataset[0]
image, _ = self.trainer.train_dataloader.iterables.dataset[0]
# 2: Convert the image into a PIL Image to bytes and encode it with base64
pil_image = T.ToPILImage()(image)

View File

@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))
- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))
### Changed
- "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490))
@ -82,6 +84,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))
- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables`([#16743](https://github.com/Lightning-AI/lightning/pull/16743))
### Deprecated
-
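Taken together, the "Added" and "Changed" entries above amount to the following user-facing surface. A minimal sketch, assuming the import path used by the tests in this PR:

from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}
combined = CombinedLoader(iterables, mode="sequential")  # new mode added by this PR
assert combined.iterables is iterables  # property renamed from `CombinedLoader.loaders`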

View File

@ -245,7 +245,7 @@ class DataConnector:
- Wrapping the dataloader based on strategy-specific logic
"""
if isinstance(dataloader, CombinedLoader):
for i, dl in enumerate(dataloader._loaders_flattened):
for i, dl in enumerate(dataloader._flattened):
dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i)
return dataloader
@ -347,7 +347,7 @@ class DataConnector:
for loader in dataloaders:
apply_to_collection(
loader.loaders if isinstance(loader, CombinedLoader) else loader,
loader.iterables if isinstance(loader, CombinedLoader) else loader,
DataLoader,
self._check_eval_shuffling,
mode=mode,
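For reference, a tiny sketch of the flatten/unflatten pattern the connector now relies on; it uses the same `torch.utils._pytree` helpers that `supporters.py` builds on, and the toy dict merely stands in for a nested collection of dataloaders:

from torch.utils._pytree import tree_flatten, tree_unflatten

nested = {"a": [1, 2], "b": 3}     # stand-in for a nested collection of dataloaders
flat, spec = tree_flatten(nested)  # flat list of leaves plus a spec describing the structure
flat[0] = 10                       # mutate through the flattened view (cf. `_update_index`)
assert tree_unflatten(flat, spec) == {"a": [10, 2], "b": 3}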

View File

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Iterator, List, Literal, Optional, Sized, Type, TypeVar
from collections.abc import Iterable
from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict
from lightning.fabric.utilities.data import sized_len
@ -73,6 +74,41 @@ class _MinSize(_ModeIterator[List]):
return [next(it) for it in self.iterators]
class _Sequential(_ModeIterator[Tuple[int, Any]]):
def __init__(self, iterables: List[Iterable]) -> None:
super().__init__(iterables)
self._iterator_idx = 0 # what would be dataloader_idx
self._idx = 0 # what would be batch_idx
def __next__(self) -> Tuple[int, Any]:
n = len(self.iterators)
if n == 0:
raise StopIteration
try:
out = next(self.iterators[self._iterator_idx])
index = self._idx
self._idx += 1
# the return value is enumerated: the batch index within the current iterable, then the batch
return index, out
except StopIteration:
self._iterator_idx += 1
self._idx = 0
if self._iterator_idx >= n:
raise
return self.__next__()
def __iter__(self) -> Self: # type: ignore[valid-type]
super().__iter__()
self._iterator_idx = 0
self._idx = 0
return self
def reset(self) -> None:
super().reset()
self._iterator_idx = 0
self._idx = 0
class _CombinationMode(TypedDict):
fn: Callable[[List[int]], int]
iterator: Type[_ModeIterator]
@ -81,9 +117,10 @@ class _CombinationMode(TypedDict):
_supported_modes = {
"min_size": _CombinationMode(fn=min, iterator=_MinSize),
"max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle),
"sequential": _CombinationMode(fn=sum, iterator=_Sequential),
}
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "sequential"]
class _CombinedDataset(Sized):
@ -114,20 +151,23 @@ class _CombinedDataset(Sized):
class CombinedLoader(Iterable):
"""Combines different dataloaders and allows sampling in parallel.
"""Combines different iterables under custom sampling modes.
Args:
loaders: the loaders to sample from. Can be all kind of collection
iterables: the iterables to sample from. Can be any kind of collection
mode:
* ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
batches) is done.
* ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
done, while cycling through the shorter loaders.
* ``"min_size"``, which raises StopIteration after the shortest iterable (the one with the lowest number of
items) is done.
* ``"max_size_cycle"``, which raises StopIteration after the longest iterable (the one with the most items) is
done, while cycling through the rest of the iterables.
* ``"sequential"``, which consumes each iterable sequentially and returns a tuple of the batch index within
the current iterable and the batch itself.
Examples:
>>> loaders = {'a': DataLoader(range(6), batch_size=4),
... 'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
>>> from torch.utils.data import DataLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
... 'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> len(combined_loader)
3
>>> for item in combined_loader:
@ -135,26 +175,33 @@ class CombinedLoader(Iterable):
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
>>> combined_loader = CombinedLoader(loaders, 'min_size')
>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> len(combined_loader)
2
>>> for item in combined_loader:
... print(item)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> len(combined_loader)
5
>>> for item in combined_loader:
... print(*item)
0 tensor([0, 1, 2, 3])
1 tensor([4, 5])
0 tensor([0, 1, 2, 3, 4])
1 tensor([5, 6, 7, 8, 9])
2 tensor([10, 11, 12, 13, 14])
"""
def __init__(self, loaders: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
if mode not in _supported_modes:
raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
# TODO(carmocca): rename loaders to iterables
self._loaders = loaders
self._loaders_flattened, self._loaders_spec = _tree_flatten(loaders)
self._iterables = iterables
self._flattened, self._spec = _tree_flatten(iterables)
# TODO(carmocca): doing this might not be necessary
datasets = _map_and_unflatten(
lambda x: getattr(x, "dataset", None), self._loaders_flattened, self._loaders_spec
)
datasets = _map_and_unflatten(lambda x: getattr(x, "dataset", None), self._flattened, self._spec)
# could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
self.dataset = _CombinedDataset(datasets, mode)
@ -162,30 +209,30 @@ class CombinedLoader(Iterable):
self._iterator: Optional[_ModeIterator] = None
@property
def loaders(self) -> Any:
"""Return the original collection of loaders."""
return self._loaders
def iterables(self) -> Any:
"""Return the original collection of iterables."""
return self._iterables
@property
def sampler(self) -> Any:
"""Return a collections of samplers extracted from loaders."""
return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._loaders_flattened, self._loaders_spec)
"""Return a collection of samplers extracted from the iterables."""
return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._flattened, self._spec)
@property
def batch_sampler(self) -> Any:
"""Return a collections of batch samplers extracted from loaders."""
return _map_and_unflatten(
lambda x: getattr(x, "batch_sampler", None), self._loaders_flattened, self._loaders_spec
)
"""Return a collection of batch samplers extracted from the iterables."""
return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self._flattened, self._spec)
def __next__(self) -> Any:
assert self._iterator is not None
out = next(self._iterator)
return tree_unflatten(out, self._loaders_spec)
if isinstance(self._iterator, _Sequential):
return out
return tree_unflatten(out, self._spec)
def __iter__(self) -> Self: # type: ignore[valid-type]
cls = _supported_modes[self._mode]["iterator"]
iterator = cls(self._loaders_flattened)
iterator = cls(self._flattened)
iter(iterator)
self._iterator = iterator
return self
@ -193,7 +240,7 @@ class CombinedLoader(Iterable):
def __len__(self) -> int:
"""Compute the number of batches."""
lengths = []
for dl in self._loaders_flattened:
for dl in self._flattened:
length = sized_len(dl)
if length is None:
raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`")
@ -205,16 +252,16 @@ class CombinedLoader(Iterable):
if self._iterator is not None:
self._iterator.reset()
self._iterator = None
for loader in self._loaders_flattened:
_shutdown_workers_and_reset_iterator(loader)
for iterable in self._flattened:
_shutdown_workers_and_reset_iterator(iterable)
def _update_index(self, dataloader: Iterable, index: int) -> None:
# mutation needs to be done using this method to avoid stale references
self._loaders_flattened[index] = dataloader
self._loaders = tree_unflatten(self._loaders_flattened, self._loaders_spec)
self._flattened[index] = dataloader
self._iterables = tree_unflatten(self._flattened, self._spec)
def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
if hasattr(dataloader, "_iterator"):
if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
dataloader._iterator._shutdown_workers()
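As context for the `_Sequential` iterator and the `__next__` special case above, here is a small behavioural sketch of the new mode, mirroring the class docstring example with a list instead of a dict:

from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = [DataLoader(range(6), batch_size=4), DataLoader(range(15), batch_size=5)]
combined = CombinedLoader(iterables, mode="sequential")
assert len(combined) == 2 + 3  # per-iterable lengths are summed (the mode's `fn` is `sum`)
for batch_idx, batch in combined:
    # `batch_idx` is the batch index within the current iterable and resets to 0
    # whenever `_Sequential` advances to the next iterable; the tuple is returned
    # as-is because `__next__` skips `tree_unflatten` for this mode
    pass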

View File

@ -285,11 +285,8 @@ class Trainer:
enable_model_summary: Whether to enable model summarization by default.
Default: ``True``.
multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
reload when reaching the minimum length of datasets.
Default: ``"max_size_cycle"``.
multiple_trainloader_mode: How to loop over the datasets when there are multiple train iterables.
See :class:`lightning.pytorch.trainer.supporters.CombinedLoader`.
inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
evaluation (``validate``/``test``/``predict``).
@ -1033,7 +1030,7 @@ class Trainer:
mode=RunningStage.TRAINING,
)
loaders = (
self.train_dataloader.loaders
self.train_dataloader.iterables
if isinstance(self.train_dataloader, CombinedLoader)
else self.train_dataloader
)
@ -1044,7 +1041,7 @@ class Trainer:
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(loaders, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
# wrap the sequence of train iterables in a CombinedLoader object for computing the num_training_batches
if not isinstance(self.train_dataloader, CombinedLoader):
self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode)
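A minimal sketch of the two Trainer-facing consequences of this diff: the mode string is forwarded to `CombinedLoader`, and user code reaching into the wrapped train dataloader now goes through `.iterables`. The `LoaderInspectionModel` below is a hypothetical example, not part of this PR:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

class LoaderInspectionModel(BoringModel):  # hypothetical example module
    def on_train_batch_start(self, batch, batch_idx):
        # `trainer.train_dataloader` is wrapped in a `CombinedLoader`; the underlying
        # collection is now exposed as `.iterables` instead of `.loaders`
        assert self.trainer.train_dataloader.iterables is not None

trainer = Trainer(multiple_trainloader_mode="max_size_cycle", fast_dev_run=True)
trainer.fit(LoaderInspectionModel())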

View File

@ -365,7 +365,7 @@ def test_manual_poptorch_dataloader(tmpdir):
assert isinstance(trainer.strategy, IPUStrategy)
assert trainer.strategy.training_opts is other_options
dataloader = trainer.train_dataloader.loaders
dataloader = trainer.train_dataloader.iterables
assert dataloader is model.poptorch_dataloader # exact object, was not recreated
# dataloader uses the options in the model, not the strategy
assert dataloader.options is model_options
@ -393,7 +393,7 @@ def test_manual_poptorch_opts(tmpdir):
assert trainer.strategy.training_opts == training_opts
assert trainer.strategy.inference_opts == inference_opts
dataloader = trainer.train_dataloader.loaders
dataloader = trainer.train_dataloader.iterables
assert isinstance(dataloader, poptorch.DataLoader)
assert dataloader.options == training_opts
assert trainer.num_devices > 1 # testing this only makes sense in a distributed setting
@ -427,7 +427,7 @@ def test_manual_poptorch_opts_custom(tmpdir):
val_dataloader = trainer.val_dataloaders[0]
train_dataloader = trainer.train_dataloader
assert isinstance(train_dataloader, CombinedLoader)
train_dataloader = train_dataloader.loaders
train_dataloader = train_dataloader.iterables
assert isinstance(val_dataloader, poptorch.DataLoader)
assert isinstance(train_dataloader, poptorch.DataLoader)
assert train_dataloader.options.replication_factor == 2

View File

@ -116,7 +116,7 @@ class TestSpawnBoringModel(BoringModel):
def on_train_end(self):
def _get_warning_msg():
dl = self.trainer.train_dataloader.loaders
dl = self.trainer.train_dataloader.iterables
if hasattr(dl, "persistent_workers"):
if self.num_workers == 0:
warn_str = "Consider setting num_workers>0 and persistent_workers=True"
@ -295,7 +295,7 @@ def test_dataloader_reinit_for_subclass():
class LoaderTestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert len(self.trainer.train_dataloader.loaders) == 10
assert len(self.trainer.train_dataloader.iterables) == 10
return super().training_step(batch, batch_idx)
def validation_step(self, batch, batch_idx):

View File

@ -74,7 +74,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
trainer.fit(model)
assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
assert isinstance(trainer.train_dataloader.iterables.sampler, SequentialSampler)
assert isinstance(trainer.val_dataloaders[0].sampler, SequentialSampler)
@ -161,6 +161,6 @@ def test_distributed_sampler_with_overfit_batches():
trainer.strategy._lightning_module = model
trainer._data_connector.attach_dataloaders(model)
trainer.reset_train_dataloader()
train_sampler = trainer.train_dataloader.loaders.sampler
train_sampler = trainer.train_dataloader.iterables.sampler
assert isinstance(train_sampler, DistributedSampler)
assert train_sampler.shuffle is False

View File

@ -160,7 +160,7 @@ def test_train_dataloader_passed_to_fit(tmpdir):
fit_options = dict(train_dataloaders=train_loader)
trainer.fit(model, **fit_options)
assert trainer.num_training_batches == 2
assert trainer.train_dataloader.loaders == train_loader
assert trainer.train_dataloader.iterables == train_loader
assert trainer.state.finished, f"Training failed with {trainer.state}"
@ -836,7 +836,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir):
[("min_size", 16), ("max_size_cycle", 64)],
)
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
"""Integration test for multiple train loaders."""
"""Integration test for multiple train iterables."""
class CustomBoringModel(BoringModel):
def train_dataloader(self):
@ -1178,12 +1178,12 @@ def test_dataloaders_reset_and_attach(tmpdir):
# 1st fit
trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
assert trainer.train_dataloader.iterables.dataset is dataloader_0.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset
# 2nd fit
trainer.fit_loop.max_steps += 1
trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
assert trainer.train_dataloader.iterables.dataset is dataloader_2.dataset
assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset
# 1st validate
@ -1316,7 +1316,7 @@ def test_request_dataloader(tmpdir):
return DataLoaderWrapper(loader)
def on_train_batch_start(self, batch, batch_idx: int) -> None:
assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
assert isinstance(self.trainer.train_dataloader.iterables, DataLoaderWrapper)
self.on_train_batch_start_called = True
def val_dataloader(self):

View File

@ -18,6 +18,7 @@ from unittest import mock
import pytest
import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
@ -26,7 +27,14 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.trainer.supporters import _CombinedDataset, CombinedLoader
from lightning.pytorch.trainer.supporters import (
_CombinedDataset,
_MaxSizeCycle,
_MinSize,
_Sequential,
_supported_modes,
CombinedLoader,
)
from tests_pytorch.helpers.runif import RunIf
@ -77,32 +85,85 @@ def test_combined_dataset_no_length():
len(cd)
def test_combined_loader_dict_min_size():
"""Test `CombinedLoaderIterator` given mapping loaders."""
loaders = {
def test_combined_loader_modes():
"""Test `CombinedLoader` in all supported modes, given mapping and sequence iterables."""
iterables = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
cl = CombinedLoader(loaders, "min_size")
lengths = [len(v) for v in iterables.values()]
for idx, item in enumerate(cl):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders)
# min_size with dict
min_len = min(lengths)
combined_loader = CombinedLoader(iterables, "min_size")
assert combined_loader._iterator is None
assert len(combined_loader) == min_len
for idx, item in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _MinSize)
assert isinstance(item, dict)
assert list(item) == ["a", "b"]
assert idx == min_len - 1
assert idx == len(combined_loader) - 1
# max_size_cycle with dict
max_len = max(lengths)
combined_loader = CombinedLoader(iterables, "max_size_cycle")
assert combined_loader._iterator is None
assert len(combined_loader) == max_len
for idx, item in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _MaxSizeCycle)
assert isinstance(item, dict)
assert list(item) == ["a", "b"]
assert idx == max_len - 1
assert idx == len(combined_loader) - 1
# sequential with dict
sum_len = sum(lengths)
combined_loader = CombinedLoader(iterables, "sequential")
assert combined_loader._iterator is None
assert len(combined_loader) == sum_len
for total_idx, (idx, item) in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _Sequential)
assert isinstance(idx, int)
assert isinstance(item, Tensor)
assert idx == lengths[-1] - 1
assert total_idx == sum_len - 1
assert total_idx == len(combined_loader) - 1
iterables = list(iterables.values())
# min_size with list
combined_loader = CombinedLoader(iterables, "min_size")
assert len(combined_loader) == min_len
for idx, item in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _MinSize)
assert isinstance(item, list)
assert len(item) == 2
assert idx == min_len - 1
assert idx == len(combined_loader) - 1
# max_size_cycle with list
combined_loader = CombinedLoader(iterables, "max_size_cycle")
assert len(combined_loader) == max_len
for idx, item in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _MaxSizeCycle)
assert isinstance(item, list)
assert len(item) == 2
assert idx == max_len - 1
assert idx == len(combined_loader) - 1
# sequential with list
combined_loader = CombinedLoader(iterables, "sequential")
assert combined_loader._iterator is None
assert len(combined_loader) == sum_len
for total_idx, (idx, item) in enumerate(combined_loader):
assert isinstance(combined_loader._iterator, _Sequential)
assert isinstance(idx, int)
assert isinstance(item, Tensor)
assert idx == lengths[-1] - 1
assert total_idx == sum_len - 1
assert total_idx == len(combined_loader) - 1
def test_combined_loader_raises():
with pytest.raises(ValueError, match="Unsupported mode 'testtt'"):
@ -113,43 +174,6 @@ def test_combined_loader_raises():
len(combined_loader)
def test_combined_loader_dict_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
loaders = {
"a": torch.utils.data.DataLoader(range(10), batch_size=4),
"b": torch.utils.data.DataLoader(range(20), batch_size=5),
}
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders.values())
for idx, item in enumerate(combined_loader):
assert isinstance(item, dict)
assert len(item) == 2
assert "a" in item and "b" in item
assert idx == len(combined_loader) - 1
def test_combined_loader_sequence_min_size():
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "min_size")
assert len(combined_loader) == min(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
assert idx == len(combined_loader) - 1
class TestIterableDataset(IterableDataset):
def __init__(self, size: int = 10):
self.size = size
@ -163,10 +187,10 @@ class TestIterableDataset(IterableDataset):
return next(self.sampler_iter)
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "sequential"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
"""Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
"""Test `CombinedLoader` given sequence iterables backed by an `IterableDataset`."""
if use_multiple_dataloaders:
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
@ -176,11 +200,9 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
loaders = [
torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
]
combined_loader = CombinedLoader(loaders, mode)
has_break = False
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2 if use_multiple_dataloaders else 1
@ -190,8 +212,13 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
if mode == "max_size_cycle":
assert all(combined_loader._iterator._consumed) == (not has_break)
expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
expected = 5
if use_multiple_dataloaders:
if mode == "max_size_cycle":
expected = 10
elif mode == "sequential":
expected = 15
assert idx == expected - 1
@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
@ -221,28 +248,8 @@ def test_combined_loader_sequence_with_map_and_iterable(lengths):
x, y = lengths
loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
dataloader = CombinedLoader(loaders, mode="max_size_cycle")
counter = 0
for _ in dataloader:
counter += 1
assert counter == max(x, y)
def test_combined_loader_sequence_max_size_cycle():
"""Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
loaders = [
torch.utils.data.DataLoader(range(10), batch_size=4),
torch.utils.data.DataLoader(range(20), batch_size=5),
]
combined_loader = CombinedLoader(loaders, "max_size_cycle")
assert len(combined_loader) == max(len(v) for v in loaders)
for idx, item in enumerate(combined_loader):
assert isinstance(item, Sequence)
assert len(item) == 2
assert idx == len(combined_loader) - 1
seen = sum(1 for _ in dataloader)
assert seen == max(x, y)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@ -340,22 +347,19 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_s
)
with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
len(dataloader)
assert len(dataloader.loaders["b"]) == 8
assert len(dataloader.iterables["b"]) == 8
dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
assert len(dataloader.loaders["b"]) == 4 if replace_sampler_ddp else 8
assert len(dataloader.iterables["b"]) == 4 if replace_sampler_ddp else 8
with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
len(dataloader)
@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
@pytest.mark.parametrize("is_min_size_mode", [False, True])
@pytest.mark.parametrize("mode", ("min_size", "max_size_cycle", "sequential"))
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_combined_dataloader_for_training_with_ddp(
replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader):
"""When providing a CombinedLoader as the training data, it should correctly receive the distributed
samplers."""
mode = "min_size" if is_min_size_mode else "max_size_cycle"
dim = 3
n1 = 8
n2 = 6
@ -371,12 +375,13 @@ def test_combined_dataloader_for_training_with_ddp(
accelerator="auto",
devices="auto",
replace_sampler_ddp=replace_sampler_ddp,
multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
multiple_trainloader_mode=mode,
)
trainer._data_connector.attach_data(
model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
)
expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
fn = _supported_modes[mode]["fn"]
expected_length_before_ddp = fn([n1, n2])
expected_length_after_ddp = (
math.ceil(expected_length_before_ddp / trainer.num_devices)
if replace_sampler_ddp

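As a quick cross-check of the expectation computed above, the per-mode length of a `CombinedLoader` is simply the corresponding reduction registered in `_supported_modes` (a sketch using the test's `n1`/`n2` values):

from lightning.pytorch.trainer.supporters import _supported_modes

lengths = [8, 6]  # n1, n2 in the test above
assert _supported_modes["min_size"]["fn"](lengths) == 6
assert _supported_modes["max_size_cycle"]["fn"](lengths) == 8
assert _supported_modes["sequential"]["fn"](lengths) == 14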
View File

@ -73,10 +73,10 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
assert model.batch_size == new_batch_size
if dm_bs == -1:
# datamodule batch size takes precedence
assert trainer.train_dataloader.loaders.batch_size == new_batch_size
assert trainer.train_dataloader.iterables.batch_size == new_batch_size
if dm_bs not in (-1, None):
assert datamodule.batch_size == new_batch_size
assert trainer.train_dataloader.loaders.batch_size == new_batch_size
assert trainer.train_dataloader.iterables.batch_size == new_batch_size
@pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"])
@ -312,7 +312,7 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
assert advance_mocked.call_count == max_trials
assert trainer.train_dataloader.loaders.batch_size == new_batch_size
assert trainer.train_dataloader.iterables.batch_size == new_batch_size
assert trainer.val_dataloaders[0].batch_size == init_batch_size
@ -357,7 +357,7 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn):
loop = getattr(trainer, f"{trainer_fn}_loop")
if trainer_fn == "fit":
expected_steps = trainer.train_dataloader.loaders.dataset.len // after_batch_size
expected_steps = trainer.train_dataloader.iterables.dataset.len // after_batch_size
assert trainer.global_step == expected_steps * max_epochs
assert trainer.current_epoch == max_epochs
assert loop.epoch_loop.batch_progress.total.completed == expected_steps * max_epochs
@ -466,4 +466,4 @@ def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expec
new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
assert new_batch_size == model.batch_size
assert new_batch_size == expected_batch_size
assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
assert trainer.train_dataloader.iterables.batch_size == expected_batch_size