Move the `CombinedLoader` to a utility file (#16819)

Carlos Mocholí 2023-02-20 18:06:35 +01:00 committed by GitHub
parent d807c003a7
commit b30a43f783
13 changed files with 34 additions and 20 deletions

View File

@@ -49,8 +49,8 @@ There are a few ways to pass multiple Datasets to Lightning:
 2. In the training loop, you can pass multiple DataLoaders as a dict or list/tuple, and Lightning will
    automatically combine the batches from different DataLoaders.
 3. In the validation, test, or prediction, you have the option to return multiple DataLoaders as a list/tuple, which Lightning will call sequentially,
-   or combine the DataLoaders using :class:`~pytorch_lightning.trainer.supporters.CombinedLoader`, which Lightning will
-   automatically combine the batches from different DataLoaders.
+   or combine the DataLoaders using :class:`~pytorch_lightning.utilities.CombinedLoader`, which is what Lightning uses
+   under the hood.

 Using LightningDataModule
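
For readers skimming the hunk above: item 2 in the list is the behaviour sketched below. This is a minimal illustration rather than an excerpt from the docs page; ``self.dataset_a``/``self.dataset_b`` are hypothetical attributes.

.. code-block:: python

    from torch.utils.data import DataLoader


    def train_dataloader(self):
        # Lightning combines these automatically; each training_step then
        # receives a dict of batches: {"a": batch_a, "b": batch_b}
        return {"a": DataLoader(self.dataset_a), "b": DataLoader(self.dataset_b)}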
@@ -174,11 +174,11 @@ Furthermore, Lightning also supports nested lists and dicts (or a combination).
     batch_c = batch_c_d["c"]
     batch_d = batch_c_d["d"]

-Alternatively, you can also pass in a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` containing multiple DataLoaders.
+Alternatively, you can also pass in a :class:`~pytorch_lightning.utilities.CombinedLoader` containing multiple DataLoaders.

 .. testcode::

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader

     def train_dataloader(self):
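
The hunk cuts the example off at the function header. A hedged completion, consistent with the `CombinedLoader` API as of this commit; the dataset sizes are arbitrary placeholders.

.. code-block:: python

    def train_dataloader(self):
        loader_a = DataLoader(range(8), batch_size=4)
        loader_b = DataLoader(range(16), batch_size=4)
        # "max_size_cycle" re-cycles the shorter loader so one combined batch
        # is produced per step until the longest loader is exhausted
        return CombinedLoader({"a": loader_a, "b": loader_b}, mode="max_size_cycle")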
@@ -222,18 +222,18 @@ Refer to the following for more details for the default sequential option:
         ...

-Evaluation DataLoaders are iterated over sequentially. If you want to iterate over them in parallel, PyTorch Lightning provides a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object which supports collections of DataLoaders such as list, tuple, or dictionary. The DataLoaders can be accessed in the same way as the provided structure:
+Evaluation DataLoaders are iterated over sequentially. The above is equivalent to:

 .. testcode::

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader

     def val_dataloader(self):
         loader_a = DataLoader()
         loader_b = DataLoader()
         loaders = {"a": loader_a, "b": loader_b}
-        combined_loaders = CombinedLoader(loaders, mode="max_size_cycle")
+        combined_loaders = CombinedLoader(loaders, mode="sequential")
         return combined_loaders
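
In ``"sequential"`` mode the loaders are consumed one after the other, and the evaluation hook is told which loader each batch came from. A short sketch of the matching hook, using the standard Lightning step signature:

.. code-block:: python

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx is 0 for batches from loader "a", 1 for loader "b"
        ...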
@@ -279,12 +279,12 @@ In the case that you require access to the DataLoader or Dataset objects, DataLo
     # extract metadata, etc. from the dataset:
     ...

-If you are using a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object, which allows you to fetch batches from a collection of DataLoaders
+If you are using a :class:`~pytorch_lightning.utilities.CombinedLoader` object, which allows you to fetch batches from a collection of DataLoaders
 simultaneously and supports collections of DataLoaders such as list, tuple, or dictionary, the DataLoaders can be accessed using the same collection structure:

 .. code-block:: python

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader

     test_dl1 = ...
     test_dl2 = ...
@@ -292,14 +292,14 @@ simultaneously which supports collections of DataLoader such as list, tuple, or
     # If you provided a list of DataLoaders:
     combined_loader = CombinedLoader([test_dl1, test_dl2])
-    list_of_loaders = combined_loader.loaders
+    list_of_loaders = combined_loader.iterables
     test_dl1 = list_of_loaders[0]

     # If you provided a dictionary of DataLoaders:
     combined_loader = CombinedLoader({"dl1": test_dl1, "dl2": test_dl2})
-    dictionary_of_loaders = combined_loader.loaders
+    dictionary_of_loaders = combined_loader.iterables
     test_dl1 = dictionary_of_loaders["dl1"]
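
The renamed ``.iterables`` attribute hands back the collection that was passed to the constructor, so identity checks hold. A self-contained sketch with toy loaders (not taken from the docs page):

.. code-block:: python

    from torch.utils.data import DataLoader
    from lightning.pytorch.utilities.combined_loader import CombinedLoader

    dl1 = DataLoader(range(4))
    dl2 = DataLoader(range(8))
    combined = CombinedLoader({"dl1": dl1, "dl2": dl2})
    # the attribute returns the original collection, not a copy
    assert combined.iterables["dl1"] is dl1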
--------------

View File

@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))

+- Moved the `CombinedLoader` class from `lightning.pytorch.trainer.supporters` to `lightning.pytorch.utilities.combined_loader` ([#16819](https://github.com/Lightning-AI/lightning/pull/16819))

 - The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))

View File

@@ -30,7 +30,7 @@ from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.exceptions import SIGTERMException
 from lightning.pytorch.utilities.model_helpers import is_overridden

View File

@@ -17,7 +17,11 @@ from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union
 from torch.utils.data.dataloader import DataLoader
 from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import (
+    _Sequential,
+    _shutdown_workers_and_reset_iterator,
+    CombinedLoader,
+)
 from lightning.pytorch.utilities.exceptions import MisconfigurationException

View File

@@ -25,7 +25,7 @@ from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
 from lightning.pytorch.utilities.model_helpers import is_overridden

View File

@@ -16,7 +16,7 @@ from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import _PREDICT_OUTPUT

View File

@@ -28,7 +28,7 @@ from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden

View File

@@ -18,6 +18,7 @@ import numpy
 from lightning.fabric.utilities import LightningEnum  # noqa: F401
 from lightning.fabric.utilities import move_data_to_device  # noqa: F401
 from lightning.pytorch.accelerators.hpu import _HPU_AVAILABLE  # noqa: F401
+from lightning.pytorch.utilities.combined_loader import CombinedLoader  # noqa: F401
 from lightning.pytorch.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from lightning.pytorch.utilities.grads import grad_norm  # noqa: F401
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCHVISION_AVAILABLE  # noqa: F401
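
The re-export added in this hunk means the class is reachable both from the ``utilities`` package and from its defining submodule. A quick check against this commit's layout:

.. code-block:: python

    from lightning.pytorch.utilities import CombinedLoader
    from lightning.pytorch.utilities.combined_loader import CombinedLoader as FullPathLoader

    # the shorthand re-export and the defining module expose the same class
    assert CombinedLoader is FullPathLoader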

View File

@@ -22,7 +22,7 @@ from lightning.pytorch import LightningDataModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher
 from lightning.pytorch.profilers import SimpleProfiler
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.runif import RunIf

View File

@@ -29,7 +29,7 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf

View File

@@ -34,7 +34,7 @@ from lightning.pytorch.demos.boring_classes import (
 )
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader

View File

@@ -25,7 +25,13 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from lightning.pytorch.trainer.supporters import _MaxSizeCycle, _MinSize, _Sequential, _supported_modes, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import (
+    _MaxSizeCycle,
+    _MinSize,
+    _Sequential,
+    _supported_modes,
+    CombinedLoader,
+)
 from tests_pytorch.helpers.runif import RunIf
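
For orientation, a hedged sketch of how the three modes named in this import differ. The batch counts assume the default ``batch_size=1`` and are illustrative, not taken from the test file:

.. code-block:: python

    from torch.utils.data import DataLoader
    from lightning.pytorch.utilities.combined_loader import CombinedLoader

    short_dl = DataLoader(range(2))
    long_dl = DataLoader(range(4))

    # "min_size" stops with the shortest loader: 2 combined batches
    # "max_size_cycle" re-cycles the shorter loader to match the longest: 4 combined batches
    # "sequential" exhausts one loader after the other: 2 + 4 = 6 batches
    for mode in ("min_size", "max_size_cycle", "sequential"):
        combined = CombinedLoader({"short": short_dl, "long": long_dl}, mode=mode)
        print(mode, len(list(combined)))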