"sequential" mode for `CombinedLoader` (#16743)
commit 6c037a479f (parent a342410e25)
@@ -81,7 +81,7 @@ class Top1:
class ProductionReadyModel(LitModule, ServableModule):
    def configure_payload(self):
        # 1: Access the train dataloader and load a single sample.
        image, _ = self.trainer.train_dataloader.loaders.dataset[0]
        image, _ = self.trainer.train_dataloader.iterables.dataset[0]

        # 2: Convert the image into a PIL Image to bytes and encode it with base64
        pil_image = T.ToPILImage()(image)
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))

- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))

### Changed

- "Native" suffix removal ([#16490](https://github.com/Lightning-AI/lightning/pull/16490))
@@ -82,6 +84,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))

- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))

### Deprecated

-
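For readers upgrading, a short sketch of what these two entries mean in user code. The import path is the one used by this PR's tests; the loader sizes are made up for illustration.

```python
from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}

# New "sequential" mode: the iterables are consumed one after the other, and each
# item comes back as a (batch_idx, batch) tuple.
combined = CombinedLoader(iterables, "sequential")
assert len(combined) == 2 + 3
for batch_idx, batch in combined:
    ...

# Renamed attribute: code that read `CombinedLoader.loaders` now reads `.iterables`.
assert combined.iterables is iterables
```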
@@ -245,7 +245,7 @@ class DataConnector:
        - Wrapping the dataloader based on strategy-specific logic
        """
        if isinstance(dataloader, CombinedLoader):
            for i, dl in enumerate(dataloader._loaders_flattened):
            for i, dl in enumerate(dataloader._flattened):
                dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i)
            return dataloader
@@ -347,7 +347,7 @@ class DataConnector:

        for loader in dataloaders:
            apply_to_collection(
                loader.loaders if isinstance(loader, CombinedLoader) else loader,
                loader.iterables if isinstance(loader, CombinedLoader) else loader,
                DataLoader,
                self._check_eval_shuffling,
                mode=mode,
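For context, a hedged sketch of the `apply_to_collection` pattern used here: the function is applied to every `DataLoader` found in a possibly nested collection, keeping the collection's structure (the import path below is an assumption, and the loaders are made up).

```python
from lightning_utilities.core.apply_func import apply_to_collection  # assumed import path
from torch.utils.data import DataLoader

loaders = {"a": DataLoader(range(4)), "b": [DataLoader(range(8)), DataLoader(range(2))]}

# apply `len` to every DataLoader in the nested collection, preserving the structure
lengths = apply_to_collection(loaders, DataLoader, len)
print(lengths)  # {'a': 4, 'b': [8, 2]}
```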
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Iterator, List, Literal, Optional, Sized, Type, TypeVar
from collections.abc import Iterable
from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar

from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict

from lightning.fabric.utilities.data import sized_len
@@ -73,6 +74,41 @@ class _MinSize(_ModeIterator[List]):
        return [next(it) for it in self.iterators]


class _Sequential(_ModeIterator[Tuple[int, Any]]):
    def __init__(self, iterables: List[Iterable]) -> None:
        super().__init__(iterables)
        self._iterator_idx = 0  # what would be dataloader_idx
        self._idx = 0  # what would be batch_idx

    def __next__(self) -> Tuple[int, Any]:
        n = len(self.iterators)
        if n == 0:
            raise StopIteration
        try:
            out = next(self.iterators[self._iterator_idx])
            index = self._idx
            self._idx += 1
            # the return is enumerated by default
            return index, out
        except StopIteration:
            self._iterator_idx += 1
            self._idx = 0
            if self._iterator_idx >= n:
                raise
            return self.__next__()

    def __iter__(self) -> Self:  # type: ignore[valid-type]
        super().__iter__()
        self._iterator_idx = 0
        self._idx = 0
        return self

    def reset(self) -> None:
        super().reset()
        self._iterator_idx = 0
        self._idx = 0
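In words, `_Sequential` drains one iterator at a time and restarts the batch index whenever it advances to the next iterable. A standalone sketch of that pattern (not the Lightning class itself):

```python
from typing import Any, Iterable, Iterator, List, Tuple


def consume_sequentially(iterables: List[Iterable]) -> Iterator[Tuple[int, Any]]:
    """Yield (batch_idx, item) pairs, exhausting each iterable before moving to the next."""
    for iterable in iterables:
        for batch_idx, item in enumerate(iterable):
            yield batch_idx, item


print(list(consume_sequentially([range(2), "ab"])))
# [(0, 0), (1, 1), (0, 'a'), (1, 'b')]
```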
class _CombinationMode(TypedDict):
    fn: Callable[[List[int]], int]
    iterator: Type[_ModeIterator]
@@ -81,9 +117,10 @@ class _CombinationMode(TypedDict):
_supported_modes = {
    "min_size": _CombinationMode(fn=min, iterator=_MinSize),
    "max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle),
    "sequential": _CombinationMode(fn=sum, iterator=_Sequential),
}

_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "sequential"]


class _CombinedDataset(Sized):
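The `fn` entry is what maps the per-iterable lengths to the combined length; the tests in this PR call it as `fn([n1, n2])`. A quick check with made-up lengths:

```python
lengths = [2, 3]          # hypothetical lengths of two flattened iterables
assert min(lengths) == 2  # "min_size": stop at the shortest
assert max(lengths) == 3  # "max_size_cycle": run to the longest, cycling the rest
assert sum(lengths) == 5  # "sequential": consume them back to back
```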
@@ -114,20 +151,23 @@ class _CombinedDataset(Sized):


class CombinedLoader(Iterable):
    """Combines different dataloaders and allows sampling in parallel.
    """Combines different iterables under custom sampling modes.

    Args:
        loaders: the loaders to sample from. Can be all kind of collection
        iterables: the loaders to sample from. Can be any kind of collection
        mode:
            * ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
              batches) is done.
            * ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
              done, while cycling through the shorter loaders.
            * ``"min_size"``, which raises StopIteration after the shortest iterable (the one with the lowest number of
              items) is done.
            * ``"max_size_cycle"`` which raises StopIteration after the longest iterable (the one with most items) is
              done, while cycling through the rest of the iterables.
            * ``"sequential"`` will consume each iterable sequentially, and return a tuple with the associated index
              from each iterable.

    Examples:
        >>> loaders = {'a': DataLoader(range(6), batch_size=4),
        ...            'b': DataLoader(range(15), batch_size=5)}
        >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
        >>> from torch.utils.data import DataLoader
        >>> iterables = {'a': DataLoader(range(6), batch_size=4),
        ...              'b': DataLoader(range(15), batch_size=5)}
        >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
        >>> len(combined_loader)
        3
        >>> for item in combined_loader:
@@ -135,26 +175,33 @@ class CombinedLoader(Iterable):
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
        >>> combined_loader = CombinedLoader(loaders, 'min_size')
        >>> combined_loader = CombinedLoader(iterables, 'min_size')
        >>> len(combined_loader)
        2
        >>> for item in combined_loader:
        ...     print(item)
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
        >>> combined_loader = CombinedLoader(iterables, 'sequential')
        >>> len(combined_loader)
        5
        >>> for item in combined_loader:
        ...     print(*item)
        0 tensor([0, 1, 2, 3])
        1 tensor([4, 5])
        0 tensor([0, 1, 2, 3, 4])
        1 tensor([5, 6, 7, 8, 9])
        2 tensor([10, 11, 12, 13, 14])
    """

    def __init__(self, loaders: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
    def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
        if mode not in _supported_modes:
            raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
        # TODO(carmocca): rename loaders to iterables
        self._loaders = loaders
        self._loaders_flattened, self._loaders_spec = _tree_flatten(loaders)
        self._iterables = iterables
        self._flattened, self._spec = _tree_flatten(iterables)

        # TODO(carmocca): doing this might not be necessary
        datasets = _map_and_unflatten(
            lambda x: getattr(x, "dataset", None), self._loaders_flattened, self._loaders_spec
        )
        datasets = _map_and_unflatten(lambda x: getattr(x, "dataset", None), self._flattened, self._spec)
        # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
        self.dataset = _CombinedDataset(datasets, mode)
@@ -162,30 +209,30 @@ class CombinedLoader(Iterable):
        self._iterator: Optional[_ModeIterator] = None

    @property
    def loaders(self) -> Any:
        """Return the original collection of loaders."""
        return self._loaders
    def iterables(self) -> Any:
        """Return the original collection of iterables."""
        return self._iterables

    @property
    def sampler(self) -> Any:
        """Return a collections of samplers extracted from loaders."""
        return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._loaders_flattened, self._loaders_spec)
        """Return a collections of samplers extracted from iterables."""
        return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self._flattened, self._spec)

    @property
    def batch_sampler(self) -> Any:
        """Return a collections of batch samplers extracted from loaders."""
        return _map_and_unflatten(
            lambda x: getattr(x, "batch_sampler", None), self._loaders_flattened, self._loaders_spec
        )
        """Return a collections of batch samplers extracted from iterables."""
        return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self._flattened, self._spec)

    def __next__(self) -> Any:
        assert self._iterator is not None
        out = next(self._iterator)
        return tree_unflatten(out, self._loaders_spec)
        if isinstance(self._iterator, _Sequential):
            return out
        return tree_unflatten(out, self._spec)

    def __iter__(self) -> Self:  # type: ignore[valid-type]
        cls = _supported_modes[self._mode]["iterator"]
        iterator = cls(self._loaders_flattened)
        iterator = cls(self._flattened)
        iter(iterator)
        self._iterator = iterator
        return self
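The `__next__` change above is the user-visible difference between modes: in "sequential" mode the `(batch_idx, batch)` tuple is returned as-is instead of being unflattened back into the input structure. A small sketch reusing the docstring's iterables:

```python
from torch.utils.data import DataLoader
from lightning.pytorch.trainer.supporters import CombinedLoader

iterables = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}

# the dict structure is restored for "min_size" / "max_size_cycle"
print(next(iter(CombinedLoader(iterables, "min_size"))))
# {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}

# "sequential" yields plain (batch_idx, batch) tuples
batch_idx, batch = next(iter(CombinedLoader(iterables, "sequential")))
print(batch_idx, batch)  # 0 tensor([0, 1, 2, 3])
```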
@@ -193,7 +240,7 @@ class CombinedLoader(Iterable):
    def __len__(self) -> int:
        """Compute the number of batches."""
        lengths = []
        for dl in self._loaders_flattened:
        for dl in self._flattened:
            length = sized_len(dl)
            if length is None:
                raise NotImplementedError(f"`{type(dl).__name__}` does not define `__len__`")
@@ -205,16 +252,16 @@ class CombinedLoader(Iterable):
        if self._iterator is not None:
            self._iterator.reset()
            self._iterator = None
        for loader in self._loaders_flattened:
            _shutdown_workers_and_reset_iterator(loader)
        for iterable in self._flattened:
            _shutdown_workers_and_reset_iterator(iterable)

    def _update_index(self, dataloader: Iterable, index: int) -> None:
        # mutation needs to be done using this method to avoid stale references
        self._loaders_flattened[index] = dataloader
        self._loaders = tree_unflatten(self._loaders_flattened, self._loaders_spec)
        self._flattened[index] = dataloader
        self._iterables = tree_unflatten(self._flattened, self._spec)


def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
    if hasattr(dataloader, "_iterator"):
        if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
            dataloader._iterator._shutdown_workers()
@@ -285,11 +285,8 @@ class Trainer:
        enable_model_summary: Whether to enable model summarization by default.
            Default: ``True``.

        multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
            In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
            and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
            reload when reaching the minimum length of datasets.
            Default: ``"max_size_cycle"``.
        multiple_trainloader_mode: How to loop over the datasets when there are multiple iterables.
            See :class:`lightning.pytorch.trainer.supporters.CombinedLoader`.

        inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
            evaluation (``validate``/``test``/``predict``).
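For orientation, a hedged sketch of what `multiple_trainloader_mode` controls: with several train iterables, the Trainer combines them into a `CombinedLoader` using this mode. The model and data below are made up for illustration.

```python
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset


class MultiLoaderModel(BoringModel):
    def train_dataloader(self):
        # two train iterables of different lengths
        return {
            "long": DataLoader(RandomDataset(32, 64), batch_size=8),
            "short": DataLoader(RandomDataset(32, 16), batch_size=8),
        }

    def training_step(self, batch, batch_idx):
        # in "max_size_cycle" mode the batch keeps the dict structure
        return self.layer(batch["long"]).sum() + self.layer(batch["short"]).sum()


trainer = Trainer(fast_dev_run=True, multiple_trainloader_mode="max_size_cycle")
trainer.fit(MultiLoaderModel())
```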
@@ -1033,7 +1030,7 @@ class Trainer:
            mode=RunningStage.TRAINING,
        )
        loaders = (
            self.train_dataloader.loaders
            self.train_dataloader.iterables
            if isinstance(self.train_dataloader, CombinedLoader)
            else self.train_dataloader
        )
@@ -1044,7 +1041,7 @@ class Trainer:
        # add worker_init_fn for correct seeding in worker processes
        apply_to_collection(loaders, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
        # wrap the sequence of train iterables to a CombinedLoader object for computing the num_training_batches
        if not isinstance(self.train_dataloader, CombinedLoader):
            self.train_dataloader = CombinedLoader(loaders, self._data_connector.multiple_trainloader_mode)
@@ -365,7 +365,7 @@ def test_manual_poptorch_dataloader(tmpdir):

    assert isinstance(trainer.strategy, IPUStrategy)
    assert trainer.strategy.training_opts is other_options
    dataloader = trainer.train_dataloader.loaders
    dataloader = trainer.train_dataloader.iterables
    assert dataloader is model.poptorch_dataloader  # exact object, was not recreated
    # dataloader uses the options in the model, not the strategy
    assert dataloader.options is model_options
@@ -393,7 +393,7 @@ def test_manual_poptorch_opts(tmpdir):
    assert trainer.strategy.training_opts == training_opts
    assert trainer.strategy.inference_opts == inference_opts

    dataloader = trainer.train_dataloader.loaders
    dataloader = trainer.train_dataloader.iterables
    assert isinstance(dataloader, poptorch.DataLoader)
    assert dataloader.options == training_opts
    assert trainer.num_devices > 1  # testing this only makes sense in a distributed setting
@@ -427,7 +427,7 @@ def test_manual_poptorch_opts_custom(tmpdir):
    val_dataloader = trainer.val_dataloaders[0]
    train_dataloader = trainer.train_dataloader
    assert isinstance(train_dataloader, CombinedLoader)
    train_dataloader = train_dataloader.loaders
    train_dataloader = train_dataloader.iterables
    assert isinstance(val_dataloader, poptorch.DataLoader)
    assert isinstance(train_dataloader, poptorch.DataLoader)
    assert train_dataloader.options.replication_factor == 2
@@ -116,7 +116,7 @@ class TestSpawnBoringModel(BoringModel):

    def on_train_end(self):
        def _get_warning_msg():
            dl = self.trainer.train_dataloader.loaders
            dl = self.trainer.train_dataloader.iterables
            if hasattr(dl, "persistent_workers"):
                if self.num_workers == 0:
                    warn_str = "Consider setting num_workers>0 and persistent_workers=True"
@@ -295,7 +295,7 @@ def test_dataloader_reinit_for_subclass():

class LoaderTestModel(BoringModel):
    def training_step(self, batch, batch_idx):
        assert len(self.trainer.train_dataloader.loaders) == 10
        assert len(self.trainer.train_dataloader.iterables) == 10
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
@@ -74,7 +74,7 @@ def test_overfit_batches_raises_warning_in_case_of_sequential_sampler(tmpdir):
    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
        trainer.fit(model)

    assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
    assert isinstance(trainer.train_dataloader.iterables.sampler, SequentialSampler)
    assert isinstance(trainer.val_dataloaders[0].sampler, SequentialSampler)
@@ -161,6 +161,6 @@ def test_distributed_sampler_with_overfit_batches():
    trainer.strategy._lightning_module = model
    trainer._data_connector.attach_dataloaders(model)
    trainer.reset_train_dataloader()
    train_sampler = trainer.train_dataloader.loaders.sampler
    train_sampler = trainer.train_dataloader.iterables.sampler
    assert isinstance(train_sampler, DistributedSampler)
    assert train_sampler.shuffle is False
@@ -160,7 +160,7 @@ def test_train_dataloader_passed_to_fit(tmpdir):
    fit_options = dict(train_dataloaders=train_loader)
    trainer.fit(model, **fit_options)
    assert trainer.num_training_batches == 2
    assert trainer.train_dataloader.loaders == train_loader
    assert trainer.train_dataloader.iterables == train_loader

    assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -836,7 +836,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir):
    [("min_size", 16), ("max_size_cycle", 64)],
)
def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches):
    """Integration test for multiple train loaders."""
    """Integration test for multiple train iterables."""

    class CustomBoringModel(BoringModel):
        def train_dataloader(self):
@@ -1178,12 +1178,12 @@ def test_dataloaders_reset_and_attach(tmpdir):

    # 1st fit
    trainer.fit(model, train_dataloaders=dataloader_0, val_dataloaders=dataloader_1)
    assert trainer.train_dataloader.loaders.dataset is dataloader_0.dataset
    assert trainer.train_dataloader.iterables.dataset is dataloader_0.dataset
    assert trainer.val_dataloaders[0].dataset is dataloader_1.dataset
    # 2nd fit
    trainer.fit_loop.max_steps += 1
    trainer.fit(model, train_dataloaders=dataloader_2, val_dataloaders=dataloader_3)
    assert trainer.train_dataloader.loaders.dataset is dataloader_2.dataset
    assert trainer.train_dataloader.iterables.dataset is dataloader_2.dataset
    assert trainer.val_dataloaders[0].dataset is dataloader_3.dataset

    # 1st validate
@@ -1316,7 +1316,7 @@ def test_request_dataloader(tmpdir):
            return DataLoaderWrapper(loader)

        def on_train_batch_start(self, batch, batch_idx: int) -> None:
            assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
            assert isinstance(self.trainer.train_dataloader.iterables, DataLoaderWrapper)
            self.on_train_batch_start_called = True

        def val_dataloader(self):
@@ -18,6 +18,7 @@ from unittest import mock

import pytest
import torch
from torch import Tensor
from torch.utils._pytree import tree_flatten
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataset import Dataset, IterableDataset
@@ -26,7 +27,14 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
from lightning.pytorch.trainer.supporters import _CombinedDataset, CombinedLoader
from lightning.pytorch.trainer.supporters import (
    _CombinedDataset,
    _MaxSizeCycle,
    _MinSize,
    _Sequential,
    _supported_modes,
    CombinedLoader,
)
from tests_pytorch.helpers.runif import RunIf
@@ -77,32 +85,85 @@ def test_combined_dataset_no_length():
        len(cd)


def test_combined_loader_dict_min_size():
    """Test `CombinedLoaderIterator` given mapping loaders."""
    loaders = {
def test_combined_loader_modes():
    """Test `CombinedLoaderIterator` given mapping iterables."""
    iterables = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }
    cl = CombinedLoader(loaders, "min_size")
    lengths = [len(v) for v in iterables.values()]

    for idx, item in enumerate(cl):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item
    assert idx == min(len(loaders["a"]), len(loaders["b"])) - 1

    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]
    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders)
    # min_size with dict
    min_len = min(lengths)
    combined_loader = CombinedLoader(iterables, "min_size")
    assert combined_loader._iterator is None
    assert len(combined_loader) == min_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with dict
    max_len = max(lengths)
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    assert combined_loader._iterator is None
    assert len(combined_loader) == max_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, dict)
        assert list(item) == ["a", "b"]
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with dict
    sum_len = sum(lengths)
    combined_loader = CombinedLoader(iterables, "sequential")
    assert combined_loader._iterator is None
    assert len(combined_loader) == sum_len
    for total_idx, (idx, item) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(idx, int)
        assert isinstance(item, Tensor)
    assert idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1

    iterables = list(iterables.values())

    # min_size with list
    combined_loader = CombinedLoader(iterables, "min_size")
    assert len(combined_loader) == min_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MinSize)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == min_len - 1
    assert idx == len(combined_loader) - 1

    # max_size_cycle with list
    combined_loader = CombinedLoader(iterables, "max_size_cycle")
    assert len(combined_loader) == max_len
    for idx, item in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _MaxSizeCycle)
        assert isinstance(item, list)
        assert len(item) == 2
    assert idx == max_len - 1
    assert idx == len(combined_loader) - 1

    # sequential with list
    combined_loader = CombinedLoader(iterables, "sequential")
    assert combined_loader._iterator is None
    assert len(combined_loader) == sum_len
    for total_idx, (idx, item) in enumerate(combined_loader):
        assert isinstance(combined_loader._iterator, _Sequential)
        assert isinstance(idx, int)
        assert isinstance(item, Tensor)
    assert idx == lengths[-1] - 1
    assert total_idx == sum_len - 1
    assert total_idx == len(combined_loader) - 1


def test_combined_loader_raises():
    with pytest.raises(ValueError, match="Unsupported mode 'testtt'"):
@@ -113,43 +174,6 @@ def test_combined_loader_raises():
        len(combined_loader)


def test_combined_loader_dict_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders."""
    loaders = {
        "a": torch.utils.data.DataLoader(range(10), batch_size=4),
        "b": torch.utils.data.DataLoader(range(20), batch_size=5),
    }

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders.values())

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, dict)
        assert len(item) == 2
        assert "a" in item and "b" in item

    assert idx == len(combined_loader) - 1


def test_combined_loader_sequence_min_size():
    """Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "min_size")

    assert len(combined_loader) == min(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1


class TestIterableDataset(IterableDataset):
    def __init__(self, size: int = 10):
        self.size = size
@@ -163,10 +187,10 @@ class TestIterableDataset(IterableDataset):
        return next(self.sampler_iter)


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"])
@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "sequential"])
@pytest.mark.parametrize("use_multiple_dataloaders", [False, True])
def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
    """Test `CombinedLoader` of mode 'min_size' given sequence loaders."""
    """Test `CombinedLoader` of mode 'min_size' given sequence iterables."""
    if use_multiple_dataloaders:
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
@@ -176,11 +200,9 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
        loaders = [
            torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2),
        ]

    combined_loader = CombinedLoader(loaders, mode)

    has_break = False

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2 if use_multiple_dataloaders else 1
@@ -190,8 +212,13 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader

    if mode == "max_size_cycle":
        assert all(combined_loader._iterator._consumed) == (not has_break)
    expected = (10 if mode == "max_size_cycle" else 5) if use_multiple_dataloaders else 5
    assert (expected - 1) == idx, (mode, use_multiple_dataloaders)
    expected = 5
    if use_multiple_dataloaders:
        if mode == "max_size_cycle":
            expected = 10
        elif mode == "sequential":
            expected = 15
    assert idx == expected - 1


@pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
@@ -221,28 +248,8 @@ def test_combined_loader_sequence_with_map_and_iterable(lengths):
    x, y = lengths
    loaders = [DataLoader(MyIterableDataset(x)), DataLoader(MyMapDataset(y))]
    dataloader = CombinedLoader(loaders, mode="max_size_cycle")
    counter = 0
    for _ in dataloader:
        counter += 1
    assert counter == max(x, y)


def test_combined_loader_sequence_max_size_cycle():
    """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders."""
    loaders = [
        torch.utils.data.DataLoader(range(10), batch_size=4),
        torch.utils.data.DataLoader(range(20), batch_size=5),
    ]

    combined_loader = CombinedLoader(loaders, "max_size_cycle")

    assert len(combined_loader) == max(len(v) for v in loaders)

    for idx, item in enumerate(combined_loader):
        assert isinstance(item, Sequence)
        assert len(item) == 2

    assert idx == len(combined_loader) - 1
    seen = sum(1 for _ in dataloader)
    assert seen == max(x, y)


@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
@@ -340,22 +347,19 @@ def test_combined_data_loader_with_max_size_cycle_and_ddp(accelerator, replace_s
    )
    with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
        len(dataloader)
    assert len(dataloader.loaders["b"]) == 8
    assert len(dataloader.iterables["b"]) == 8
    dataloader = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
    assert len(dataloader.loaders["b"]) == 4 if replace_sampler_ddp else 8
    assert len(dataloader.iterables["b"]) == 4 if replace_sampler_ddp else 8
    with pytest.raises(NotImplementedError, match="DataLoader` does not define `__len__"):
        len(dataloader)


@pytest.mark.parametrize("replace_sampler_ddp", [False, True])
@pytest.mark.parametrize("is_min_size_mode", [False, True])
@pytest.mark.parametrize("mode", ("min_size", "max_size_cycle", "sequential"))
@pytest.mark.parametrize("use_combined_loader", [False, True])
def test_combined_dataloader_for_training_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, use_combined_loader: bool
):
def test_combined_dataloader_for_training_with_ddp(replace_sampler_ddp, mode, use_combined_loader):
    """When providing a CombinedLoader as the training data, it should correctly receive the distributed
    samplers."""
    mode = "min_size" if is_min_size_mode else "max_size_cycle"
    dim = 3
    n1 = 8
    n2 = 6
@@ -371,12 +375,13 @@ def test_combined_dataloader_for_training_with_ddp(
        accelerator="auto",
        devices="auto",
        replace_sampler_ddp=replace_sampler_ddp,
        multiple_trainloader_mode="max_size_cycle" if use_combined_loader else mode,
        multiple_trainloader_mode=mode,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    expected_length_before_ddp = min(n1, n2) if is_min_size_mode else max(n1, n2)
    fn = _supported_modes[mode]["fn"]
    expected_length_before_ddp = fn([n1, n2])
    expected_length_after_ddp = (
        math.ceil(expected_length_before_ddp / trainer.num_devices)
        if replace_sampler_ddp
@@ -73,10 +73,10 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
    assert model.batch_size == new_batch_size
    if dm_bs == -1:
        # datamodule batch size takes precedence
        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
        assert trainer.train_dataloader.iterables.batch_size == new_batch_size
    if dm_bs not in (-1, None):
        assert datamodule.batch_size == new_batch_size
        assert trainer.train_dataloader.loaders.batch_size == new_batch_size
        assert trainer.train_dataloader.iterables.batch_size == new_batch_size


@pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"])
@@ -312,7 +312,7 @@ def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
    new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
    assert advance_mocked.call_count == max_trials

    assert trainer.train_dataloader.loaders.batch_size == new_batch_size
    assert trainer.train_dataloader.iterables.batch_size == new_batch_size
    assert trainer.val_dataloaders[0].batch_size == init_batch_size
@@ -357,7 +357,7 @@ def test_batch_size_finder_callback(tmpdir, trainer_fn):
    loop = getattr(trainer, f"{trainer_fn}_loop")

    if trainer_fn == "fit":
        expected_steps = trainer.train_dataloader.loaders.dataset.len // after_batch_size
        expected_steps = trainer.train_dataloader.iterables.dataset.len // after_batch_size
        assert trainer.global_step == expected_steps * max_epochs
        assert trainer.current_epoch == max_epochs
        assert loop.epoch_loop.batch_progress.total.completed == expected_steps * max_epochs
@@ -466,4 +466,4 @@ def test_dataloader_batch_size_updated_on_failure(_, tmpdir, scale_method, expec
    new_batch_size = tuner.scale_batch_size(model, **scale_batch_size_kwargs)
    assert new_batch_size == model.batch_size
    assert new_batch_size == expected_batch_size
    assert trainer.train_dataloader.loaders.batch_size == expected_batch_size
    assert trainer.train_dataloader.iterables.batch_size == expected_batch_size