diff --git a/examples/pl_servable_module/production.py b/examples/pl_servable_module/production.py index 7b086c5faf..22174729d4 100644 --- a/examples/pl_servable_module/production.py +++ b/examples/pl_servable_module/production.py @@ -81,7 +81,7 @@ class Top1: class ProductionReadyModel(LitModule, ServableModule): def configure_payload(self): # 1: Access the train dataloader and load a single sample. - image, _ = self.trainer.train_dataloader.loaders.dataset[0] + image, _ = self.trainer.train_dataloader.iterables.dataset[0] # 2: Convert the image into a PIL Image to bytes and encode it with base64 pil_image = T.ToPILImage()(image) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index ffc48e7fa2..30d0b42258 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646)) +- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743)) + ### Changed - "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490)) @@ -82,6 +84,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365)) +- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables`([#16743](https://github.com/Lightning-AI/lightning/pull/16743)) + ### Deprecated - diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 4dd4a1adb4..8f16fb82a0 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -245,7 +245,7 @@ class DataConnector: - Wrapping the dataloader based on strategy-specific logic """ if isinstance(dataloader, CombinedLoader): - for i, dl in enumerate(dataloader._loaders_flattened): + for i, dl in enumerate(dataloader._flattened): dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i) return dataloader @@ -347,7 +347,7 @@ class DataConnector: for loader in dataloaders: apply_to_collection( - loader.loaders if isinstance(loader, CombinedLoader) else loader, + loader.iterables if isinstance(loader, CombinedLoader) else loader, DataLoader, self._check_eval_shuffling, mode=mode, diff --git a/src/lightning/pytorch/trainer/supporters.py b/src/lightning/pytorch/trainer/supporters.py index e33a932235..2c3872e239 100644 --- a/src/lightning/pytorch/trainer/supporters.py +++ b/src/lightning/pytorch/trainer/supporters.py @@ -11,9 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Iterable, Iterator, List, Literal, Optional, Sized, Type, TypeVar
+from collections.abc import Iterable
+from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar
 
-from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
+from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
 from typing_extensions import Self, TypedDict
 
 from lightning.fabric.utilities.data import sized_len
@@ -73,6 +74,41 @@ class _MinSize(_ModeIterator[List]):
         return [next(it) for it in self.iterators]
 
 
+class _Sequential(_ModeIterator[Tuple[int, Any]]):
+    def __init__(self, iterables: List[Iterable]) -> None:
+        super().__init__(iterables)
+        self._iterator_idx = 0  # what would be dataloader_idx
+        self._idx = 0  # what would be batch_idx
+
+    def __next__(self) -> Tuple[int, Any]:
+        n = len(self.iterators)
+        if n == 0:
+            raise StopIteration
+        try:
+            out = next(self.iterators[self._iterator_idx])
+            index = self._idx
+            self._idx += 1
+            # the return is enumerated by default
+            return index, out
+        except StopIteration:
+            self._iterator_idx += 1
+            self._idx = 0
+            if self._iterator_idx >= n:
+                raise
+            return self.__next__()
+
+    def __iter__(self) -> Self:  # type: ignore[valid-type]
+        super().__iter__()
+        self._iterator_idx = 0
+        self._idx = 0
+        return self
+
+    def reset(self) -> None:
+        super().reset()
+        self._iterator_idx = 0
+        self._idx = 0
+
+
 class _CombinationMode(TypedDict):
     fn: Callable[[List[int]], int]
     iterator: Type[_ModeIterator]
@@ -81,9 +117,10 @@ class _CombinationMode(TypedDict):
 _supported_modes = {
     "min_size": _CombinationMode(fn=min, iterator=_MinSize),
     "max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle),
+    "sequential": _CombinationMode(fn=sum, iterator=_Sequential),
 }
 
-_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]
+_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "sequential"]
 
 
 class _CombinedDataset(Sized):
@@ -114,20 +151,23 @@ class _CombinedDataset(Sized):
 
 
 class CombinedLoader(Iterable):
-    """Combines different dataloaders and allows sampling in parallel.
+    """Combines different iterables under custom sampling modes.
 
     Args:
-        loaders: the loaders to sample from. Can be all kind of collection
+        iterables: the iterables to sample from. Can be any kind of collection
         mode:
-            * ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
-              batches) is done.
-            * ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
-              done, while cycling through the shorter loaders.
+            * ``"min_size"``, which raises StopIteration after the shortest iterable (the one with the lowest number of
+              items) is done.
+            * ``"max_size_cycle"``, which raises StopIteration after the longest iterable (the one with most items) is
+              done, while cycling through the rest of the iterables.
+            * ``"sequential"``, which consumes each iterable sequentially, and returns a tuple with the associated
+              index from each iterable.
 
     Examples:
-        >>> loaders = {'a': DataLoader(range(6), batch_size=4),
-        ...            'b': DataLoader(range(15), batch_size=5)}
-        >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+        >>> from torch.utils.data import DataLoader
+        >>> iterables = {'a': DataLoader(range(6), batch_size=4),
+        ...              'b': DataLoader(range(15), batch_size=5)}
+        >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
         >>> len(combined_loader)
         3
         >>> for item in combined_loader:
@@ -135,26 +175,33 @@ class CombinedLoader(Iterable):
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
-        >>> combined_loader = CombinedLoader(loaders, 'min_size')
+        >>> combined_loader = CombinedLoader(iterables, 'min_size')
         >>> len(combined_loader)
         2
         >>> for item in combined_loader:
         ...     print(item)
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
+        >>> combined_loader = CombinedLoader(iterables, 'sequential')
+        >>> len(combined_loader)
+        5
+        >>> for item in combined_loader:
+        ...     print(*item)
+        0 tensor([0, 1, 2, 3])
+        1 tensor([4, 5])
+        0 tensor([0, 1, 2, 3, 4])
+        1 tensor([5, 6, 7, 8, 9])
+        2 tensor([10, 11, 12, 13, 14])
     """
 
-    def __init__(self, loaders: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
+    def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
         if mode not in _supported_modes:
             raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
-        # TODO(carmocca): rename loaders to iterables
-        self._loaders = loaders
-        self._loaders_flattened, self._loaders_spec = _tree_flatten(loaders)
+        self._iterables = iterables
+        self._flattened, self._spec = _tree_flatten(iterables)
         # TODO(carmocca): doing this might not be necessary
-        datasets = _map_and_unflatten(
-            lambda x: getattr(x, "dataset", None), self._loaders_flattened, self._loaders_spec
-        )
+        datasets = _map_and_unflatten(lambda x: getattr(x, "dataset", None), self._flattened, self._spec)
         # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
         self.dataset = _CombinedDataset(datasets, mode)
@@ -162,30 +209,30 @@ class CombinedLoader(Iterable):
         self._iterator: Optional[_ModeIterator] = None
 
     @property
-    def loaders(self) -> Any:
-        """Return the original collection of loaders."""
-        return self._loaders
+    def iterables(self) -> Any:
+        """Return the original collection of iterables."""
+        return self._iterables
 
     @property
     def sampler(self) -> Any:
-        """Return a collections of samplers extracted from loaders."""
-        return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._loaders_flattened, self._loaders_spec)
+        """Return a collection of samplers extracted from the iterables."""
+        return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._flattened, self._spec)
 
     @property
     def batch_sampler(self) -> Any:
-        """Return a collections of batch samplers extracted from loaders."""
-        return _map_and_unflatten(
-            lambda x: getattr(x, "batch_sampler", None), self._loaders_flattened, self._loaders_spec
-        )
+        """Return a collection of batch samplers extracted from the iterables."""
+        return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self._flattened, self._spec)
 
     def __next__(self) -> Any:
         assert self._iterator is not None
         out = next(self._iterator)
-        return tree_unflatten(out, self._loaders_spec)
+        if isinstance(self._iterator, _Sequential):
+            return out
+        return tree_unflatten(out, self._spec)
 
     def __iter__(self) -> Self:  # type: ignore[valid-type]
         cls = _supported_modes[self._mode]["iterator"]
-        iterator = cls(self._loaders_flattened)
+        iterator = cls(self._flattened)
         iter(iterator)
         self._iterator = iterator
         return self
@@ -193,7 +240,7 @@ class
CombinedLoader(Iterable): def __len__(self) -> int: """Compute the number of batches.""" lengths = [] - for dl in self._loaders_flattened: + for dl in self._flattened: length = sized_len(dl) if length is None: raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`") @@ -205,16 +252,16 @@ class CombinedLoader(Iterable): if self._iterator is not None: self._iterator.reset() self._iterator = None - for loader in self._loaders_flattened: - _shutdown_workers_and_reset_iterator(loader) + for iterable in self._flattened: + _shutdown_workers_and_reset_iterator(iterable) def _update_index(self, dataloader: Iterable, index: int) -> None: # mutation needs to be done using this method to avoid stale references - self._loaders_flattened[index] = dataloader - self._loaders = tree_unflatten(self._loaders_flattened, self._loaders_spec) + self._flattened[index] = dataloader + self._iterables = tree_unflatten(self._flattened, self._spec) -def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None: +def _shutdown_workers_and_reset_iterator(dataloader: object) -> None: if hasattr(dataloader, "_iterator"): if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter): dataloader._iterator._shutdown_workers() diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 70c0c98887..75c3caea62 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -285,11 +285,8 @@ class Trainer: enable_model_summary: Whether to enable model summarization by default. Default: ``True``. - multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders. - In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, - and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets - reload when reaching the minimum length of datasets. - Default: ``"max_size_cycle"``. + multiple_trainloader_mode: How to loop over the datasets when there are multiple iterables. + See :class:`lightning.pytorch.trainer.supporters.CombinedLoader`. inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during evaluation (``validate``/``test``/``predict``). 
@@ -1033,7 +1030,7 @@ class Trainer: mode=RunningStage.TRAINING, ) loaders = ( - self.train_dataloader.loaders + self.train_dataloader.iterables if isinstance(self.train_dataloader, CombinedLoader) else self.train_dataloader ) @@ -1044,7 +1041,7 @@ class Trainer: # add worker_init_fn for correct seeding in worker processes apply_to_collection(loaders, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank) - # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches + # wrap the sequence of train iterables to a CombinedLoader object for computing the num_training_batches if not isinstance(self.train_dataloader, CombinedLoader): self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index fa1bf7c0b7..fb3cbe8ce7 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -365,7 +365,7 @@ def test_manual_poptorch_dataloader(tmpdir): assert isinstance(trainer.strategy, IPUStrategy) assert trainer.strategy.training_opts is other_options - dataloader = trainer.train_dataloader.loaders + dataloader = trainer.train_dataloader.iterables assert dataloader is model.poptorch_dataloader # exact object, was not recreated # dataloader uses the options in the model, not the strategy assert dataloader.options is model_options @@ -393,7 +393,7 @@ def test_manual_poptorch_opts(tmpdir): assert trainer.strategy.training_opts == training_opts assert trainer.strategy.inference_opts == inference_opts - dataloader = trainer.train_dataloader.loaders + dataloader = trainer.train_dataloader.iterables assert isinstance(dataloader, poptorch.DataLoader) assert dataloader.options == training_opts assert trainer.num_devices > 1 # testing this only makes sense in a distributed setting @@ -427,7 +427,7 @@ def test_manual_poptorch_opts_custom(tmpdir): val_dataloader = trainer.val_dataloaders[0] train_dataloader = trainer.train_dataloader assert isinstance(train_dataloader, CombinedLoader) - train_dataloader = train_dataloader.loaders + train_dataloader = train_dataloader.iterables assert isinstance(val_dataloader, poptorch.DataLoader) assert isinstance(train_dataloader, poptorch.DataLoader) assert train_dataloader.options.replication_factor == 2 diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index e183a9252a..d265172085 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -116,7 +116,7 @@ class TestSpawnBoringModel(BoringModel): def on_train_end(self): def _get_warning_msg(): - dl = self.trainer.train_dataloader.loaders + dl = self.trainer.train_dataloader.iterables if hasattr(dl, "persistent_workers"): if self.num_workers == 0: warn_str = "Consider setting num_workers>0 and persistent_workers=True" @@ -295,7 +295,7 @@ def test_dataloader_reinit_for_subclass(): class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): - assert len(self.trainer.train_dataloader.loaders) == 10 + assert len(self.trainer.train_dataloader.iterables) == 10 return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 6e34560b38..36a75717c6 
100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -74,7 +74,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir): with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"): trainer.fit(model) - assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler) + assert isinstance(trainer.train_dataloader.iterables.sampler, SequentialSampler) assert isinstance(trainer.val_dataloaders[0].sampler, SequentialSampler) @@ -161,6 +161,6 @@ def test_distributed_sampler_with_overfit_batches(): trainer.strategy._lightning_module = model trainer._data_connector.attach_dataloaders(model) trainer.reset_train_dataloader() - train_sampler = trainer.train_dataloader.loaders.sampler + train_sampler = trainer.train_dataloader.iterables.sampler assert isinstance(train_sampler, DistributedSampler) assert train_sampler.shuffle is False diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 1307c3d823..adbb10d12e 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -160,7 +160,7 @@ def test_train_dataloader_passed_to_fit(tmpdir): fit_options = dict(train_dataloaders=train_loader) trainer.fit(model, **fit_options) assert trainer.num_training_batches == 2 - assert trainer.train_dataloader.loaders == train_loader + assert trainer.train_dataloader.iterables == train_loader assert trainer.state.finished, f"Training failed with {trainer.state}" @@ -836,7 +836,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir): [("min_size", 16), ("max_size_cycle", 64)], ) def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): - """Integration test for multiple train loaders.""" + """Integration test for multiple train iterables.""" class CustomBoringModel(BoringModel): def train_dataloader(self): @@ -1178,12 +1178,12 @@ def test_dataloaders_reset_and_attach(tmpdir): # 1st fit trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1) - assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset + assert trainer.train_dataloader.iterables.dataset is dataloader_0.dataset assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset # 2nd fit trainer.fit_loop.max_steps += 1 trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3) - assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset + assert trainer.train_dataloader.iterables.dataset is dataloader_2.dataset assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset # 1st validate @@ -1316,7 +1316,7 @@ def test_request_dataloader(tmpdir): return DataLoaderWrapper(loader) def on_train_batch_start(self, batch, batch_idx: int) -> None: - assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) + assert isinstance(self.trainer.train_dataloader.iterables, DataLoaderWrapper) self.on_train_batch_start_called = True def val_dataloader(self): diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/trainer/test_supporters.py index bed648f468..01025975a2 100644 --- a/tests/tests_pytorch/trainer/test_supporters.py +++ b/tests/tests_pytorch/trainer/test_supporters.py @@ -18,6 +18,7 @@ from unittest import mock import pytest import torch +from torch import Tensor from torch.utils._pytree import tree_flatten from 
torch.utils.data import DataLoader, TensorDataset from torch.utils.data.dataset import Dataset, IterableDataset @@ -26,7 +27,14 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.trainer.supporters import _CombinedDataset, CombinedLoader +from lightning.pytorch.trainer.supporters import ( + _CombinedDataset, + _MaxSizeCycle, + _MinSize, + _Sequential, + _supported_modes, + CombinedLoader, +) from tests_pytorch.helpers.runif import RunIf @@ -77,32 +85,85 @@ def test_combined_dataset_no_length(): len(cd) -def test_combined_loader_dict_min_size(): - """Test `CombinedLoaderIterator` given mapping loaders.""" - loaders = { +def test_combined_loader_modes(): + """Test `CombinedLoaderIterator` given mapping iterables.""" + iterables = { "a": torch.utils.data.DataLoader(range(10), batch_size=4), "b": torch.utils.data.DataLoader(range(20), batch_size=5), } - cl = CombinedLoader(loaders, "min_size") + lengths = [len(v) for v in iterables.values()] - for idx, item in enumerate(cl): - assert isinstance(item, dict) - assert len(item) == 2 - assert "a" in item and "b" in item - assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1 - - loaders = [ - torch.utils.data.DataLoader(range(10), batch_size=4), - torch.utils.data.DataLoader(range(20), batch_size=5), - ] - combined_loader = CombinedLoader(loaders, "min_size") - - assert len(combined_loader) == min(len(v) for v in loaders) + # min_size with dict + min_len = min(lengths) + combined_loader = CombinedLoader(iterables, "min_size") + assert combined_loader._iterator is None + assert len(combined_loader) == min_len for idx, item in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _MinSize) + assert isinstance(item, dict) + assert list(item) == ["a", "b"] + assert idx == min_len - 1 + assert idx == len(combined_loader) - 1 + + # max_size_cycle with dict + max_len = max(lengths) + combined_loader = CombinedLoader(iterables, "max_size_cycle") + assert combined_loader._iterator is None + assert len(combined_loader) == max_len + for idx, item in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _MaxSizeCycle) + assert isinstance(item, dict) + assert list(item) == ["a", "b"] + assert idx == max_len - 1 + assert idx == len(combined_loader) - 1 + + # sequential with dict + sum_len = sum(lengths) + combined_loader = CombinedLoader(iterables, "sequential") + assert combined_loader._iterator is None + assert len(combined_loader) == sum_len + for total_idx, (idx, item) in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _Sequential) + assert isinstance(idx, int) + assert isinstance(item, Tensor) + assert idx == lengths[-1] - 1 + assert total_idx == sum_len - 1 + assert total_idx == len(combined_loader) - 1 + + iterables = list(iterables.values()) + + # min_size with list + combined_loader = CombinedLoader(iterables, "min_size") + assert len(combined_loader) == min_len + for idx, item in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _MinSize) assert isinstance(item, list) assert len(item) == 2 + assert idx == min_len - 1 assert idx == len(combined_loader) - 1 + # max_size_cycle with list + combined_loader = CombinedLoader(iterables, "max_size_cycle") + assert len(combined_loader) == max_len + for idx, item in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _MaxSizeCycle) 
+ assert isinstance(item, list) + assert len(item) == 2 + assert idx == max_len - 1 + assert idx == len(combined_loader) - 1 + + # sequential with list + combined_loader = CombinedLoader(iterables, "sequential") + assert combined_loader._iterator is None + assert len(combined_loader) == sum_len + for total_idx, (idx, item) in enumerate(combined_loader): + assert isinstance(combined_loader._iterator, _Sequential) + assert isinstance(idx, int) + assert isinstance(item, Tensor) + assert idx == lengths[-1] - 1 + assert total_idx == sum_len - 1 + assert total_idx == len(combined_loader) - 1 + def test_combined_loader_raises(): with pytest.raises(ValueError, match="Unsupported mode 'testtt'"): @@ -113,43 +174,6 @@ def test_combined_loader_raises(): len(combined_loader) -def test_combined_loader_dict_max_size_cycle(): - """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders.""" - loaders = { - "a": torch.utils.data.DataLoader(range(10), batch_size=4), - "b": torch.utils.data.DataLoader(range(20), batch_size=5), - } - - combined_loader = CombinedLoader(loaders, "max_size_cycle") - - assert len(combined_loader) == max(len(v) for v in loaders.values()) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, dict) - assert len(item) == 2 - assert "a" in item and "b" in item - - assert idx == len(combined_loader) - 1 - - -def test_combined_loader_sequence_min_size(): - """Test `CombinedLoader` of mode 'min_size' given sequence loaders.""" - loaders = [ - torch.utils.data.DataLoader(range(10), batch_size=4), - torch.utils.data.DataLoader(range(20), batch_size=5), - ] - - combined_loader = CombinedLoader(loaders, "min_size") - - assert len(combined_loader) == min(len(v) for v in loaders) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, Sequence) - assert len(item) == 2 - - assert idx == len(combined_loader) - 1 - - class TestIterableDataset(IterableDataset): def __init__(self, size: int = 10): self.size = size @@ -163,10 +187,10 @@ class TestIterableDataset(IterableDataset): return next(self.sampler_iter) -@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"]) +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "sequential"]) @pytest.mark.parametrize("use_multiple_dataloaders", [False, True]) def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders): - """Test `CombinedLoader` of mode 'min_size' given sequence loaders.""" + """Test `CombinedLoader` of mode 'min_size' given sequence iterables.""" if use_multiple_dataloaders: loaders = [ torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2), @@ -176,11 +200,9 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader loaders = [ torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2), ] - combined_loader = CombinedLoader(loaders, mode) has_break = False - for idx, item in enumerate(combined_loader): assert isinstance(item, Sequence) assert len(item) == 2 if use_multiple_dataloaders else 1 @@ -190,8 +212,13 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader if mode == "max_size_cycle": assert all(combined_loader._iterator._consumed) == (not has_break) - expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5 - assert (expected - 1) == idx, (mode, use_multiple_dataloaders) + expected = 5 + if use_multiple_dataloaders: + if mode == "max_size_cycle": + expected = 10 + elif mode == "sequential": + expected = 15 + assert idx == expected - 1 
@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]]) @@ -221,28 +248,8 @@ def test_combined_loader_sequence_with_map_and_iterable(lengths): x, y = lengths loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))] dataloader = CombinedLoader(loaders, mode="max_size_cycle") - counter = 0 - for _ in dataloader: - counter += 1 - assert counter == max(x, y) - - -def test_combined_loader_sequence_max_size_cycle(): - """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders.""" - loaders = [ - torch.utils.data.DataLoader(range(10), batch_size=4), - torch.utils.data.DataLoader(range(20), batch_size=5), - ] - - combined_loader = CombinedLoader(loaders, "max_size_cycle") - - assert len(combined_loader) == max(len(v) for v in loaders) - - for idx, item in enumerate(combined_loader): - assert isinstance(item, Sequence) - assert len(item) == 2 - - assert idx == len(combined_loader) - 1 + seen = sum(1 for _ in dataloader) + assert seen == max(x, y) @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"}) @@ -340,22 +347,19 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_s ) with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"): len(dataloader) - assert len(dataloader.loaders["b"]) == 8 + assert len(dataloader.iterables["b"]) == 8 dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False) - assert len(dataloader.loaders["b"]) == 4 if replace_sampler_ddp else 8 + assert len(dataloader.iterables["b"]) == 4 if replace_sampler_ddp else 8 with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"): len(dataloader) @pytest.mark.parametrize("replace_sampler_ddp", [False, True]) -@pytest.mark.parametrize("is_min_size_mode", [False, True]) +@pytest.mark.parametrize("mode", ("min_size", "max_size_cycle", "sequential")) @pytest.mark.parametrize("use_combined_loader", [False, True]) -def test_combined_dataloader_for_training_with_ddp( - replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool -): +def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader): """When providing a CombinedLoader as the training data, it should be correctly receive the distributed samplers.""" - mode = "min_size" if is_min_size_mode else "max_size_cycle" dim = 3 n1 = 8 n2 = 6 @@ -371,12 +375,13 @@ def test_combined_dataloader_for_training_with_ddp( accelerator="auto", devices="auto", replace_sampler_ddp=replace_sampler_ddp, - multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode, + multiple_trainloader_mode=mode, ) trainer._data_connector.attach_data( model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None ) - expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2) + fn = _supported_modes[mode]["fn"] + expected_length_before_ddp = fn([n1, n2]) expected_length_after_ddp = ( math.ceil(expected_length_before_ddp / trainer.num_devices) if replace_sampler_ddp diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index 2e12ea1729..08b94a4763 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -73,10 +73,10 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b assert model.batch_size == new_batch_size if dm_bs == -1: # datamodule batch size takes precedence - assert 
trainer.train_dataloader.loaders.batch_size == new_batch_size + assert trainer.train_dataloader.iterables.batch_size == new_batch_size if dm_bs not in (-1, None): assert datamodule.batch_size == new_batch_size - assert trainer.train_dataloader.loaders.batch_size == new_batch_size + assert trainer.train_dataloader.iterables.batch_size == new_batch_size @pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"]) @@ -312,7 +312,7 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method): new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs) assert advance_mocked.call_count == max_trials - assert trainer.train_dataloader.loaders.batch_size == new_batch_size + assert trainer.train_dataloader.iterables.batch_size == new_batch_size assert trainer.val_dataloaders[0].batch_size == init_batch_size @@ -357,7 +357,7 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn): loop = getattr(trainer, f"{trainer_fn}_loop") if trainer_fn == "fit": - expected_steps = trainer.train_dataloader.loaders.dataset.len // after_batch_size + expected_steps = trainer.train_dataloader.iterables.dataset.len // after_batch_size assert trainer.global_step == expected_steps * max_epochs assert trainer.current_epoch == max_epochs assert loop.epoch_loop.batch_progress.total.completed == expected_steps * max_epochs @@ -466,4 +466,4 @@ def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expec new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs) assert new_batch_size == model.batch_size assert new_batch_size == expected_batch_size - assert trainer.train_dataloader.loaders.batch_size == expected_batch_size + assert trainer.train_dataloader.iterables.batch_size == expected_batch_size
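Usage sketch (illustrative only, assembled from the docstring and tests in this diff; not part of the patch) showing the renamed `iterables` property and the new "sequential" mode:

    from torch.utils.data import DataLoader
    from lightning.pytorch.trainer.supporters import CombinedLoader

    iterables = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}
    combined = CombinedLoader(iterables, mode="sequential")

    # "sequential" sums the individual lengths: 2 batches from "a" + 3 batches from "b"
    assert len(combined) == 5

    # unlike "min_size"/"max_size_cycle", each item is a (batch_idx, batch) tuple and the
    # iterables are exhausted one after the other, with batch_idx restarting at 0 per iterable
    for batch_idx, batch in combined:
        print(batch_idx, batch)

    # the former `CombinedLoader.loaders` property is now `CombinedLoader.iterables`
    assert combined.iterables is iterables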