Move the `CombinedLoader` to an utility file (#16819)
This commit is contained in:
parent
d807c003a7
commit
b30a43f783
|
@ -49,8 +49,8 @@ There are a few ways to pass multiple Datasets to Lightning:
|
|||
2. In the training loop, you can pass multiple DataLoaders as a dict or list/tuple, and Lightning will
|
||||
automatically combine the batches from different DataLoaders.
|
||||
3. In the validation, test, or prediction, you have the option to return multiple DataLoaders as list/tuple, which Lightning will call sequentially
|
||||
or combine the DataLoaders using :class:`~pytorch_lightning.trainer.supporters.CombinedLoader`, which Lightning will
|
||||
automatically combine the batches from different DataLoaders.
|
||||
or combine the DataLoaders using :class:`~pytorch_lightning.utilities.CombinedLoader`, which is what Lightning uses
|
||||
under the hood.
|
||||
|
||||
|
||||
Using LightningDataModule
|
||||
|
@ -174,11 +174,11 @@ Furthermore, Lightning also supports nested lists and dicts (or a combination).
|
|||
batch_c = batch_c_d["c"]
|
||||
batch_d = batch_c_d["d"]
|
||||
|
||||
Alternatively, you can also pass in a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` containing multiple DataLoaders.
|
||||
Alternatively, you can also pass in a :class:`~pytorch_lightning.utilities.CombinedLoader` containing multiple DataLoaders.
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities import CombinedLoader
|
||||
|
||||
|
||||
def train_dataloader(self):
|
||||
|
@ -222,18 +222,18 @@ Refer to the following for more details for the default sequential option:
|
|||
...
|
||||
|
||||
|
||||
Evaluation DataLoaders are iterated over sequentially. If you want to iterate over them in parallel, PyTorch Lightning provides a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object which supports collections of DataLoaders such as list, tuple, or dictionary. The DataLoaders can be accessed using in the same way as the provided structure:
|
||||
Evaluation DataLoaders are iterated over sequentially. The above is equivalent to:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities import CombinedLoader
|
||||
|
||||
|
||||
def val_dataloader(self):
|
||||
loader_a = DataLoader()
|
||||
loader_b = DataLoader()
|
||||
loaders = {"a": loader_a, "b": loader_b}
|
||||
combined_loaders = CombinedLoader(loaders, mode="max_size_cycle")
|
||||
combined_loaders = CombinedLoader(loaders, mode="sequential")
|
||||
return combined_loaders
|
||||
|
||||
|
||||
|
@ -279,12 +279,12 @@ In the case that you require access to the DataLoader or Dataset objects, DataLo
|
|||
# extract metadata, etc. from the dataset:
|
||||
...
|
||||
|
||||
If you are using a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object which allows you to fetch batches from a collection of DataLoaders
|
||||
If you are using a :class:`~pytorch_lightning.utilities.CombinedLoader` object which allows you to fetch batches from a collection of DataLoaders
|
||||
simultaneously which supports collections of DataLoader such as list, tuple, or dictionary. The DataLoaders can be accessed using the same collection structure:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader
|
||||
from pytorch_lightning.utilities import CombinedLoader
|
||||
|
||||
test_dl1 = ...
|
||||
test_dl2 = ...
|
||||
|
@ -292,14 +292,14 @@ simultaneously which supports collections of DataLoader such as list, tuple, or
|
|||
# If you provided a list of DataLoaders:
|
||||
|
||||
combined_loader = CombinedLoader([test_dl1, test_dl2])
|
||||
list_of_loaders = combined_loader.loaders
|
||||
list_of_loaders = combined_loader.iterables
|
||||
test_dl1 = list_of_loaders.loaders[0]
|
||||
|
||||
|
||||
# If you provided dictionary of DataLoaders:
|
||||
|
||||
combined_loader = CombinedLoader({"dl1": test_dl1, "dl2": test_dl2})
|
||||
dictionary_of_loaders = combined_loader.loaders
|
||||
dictionary_of_loaders = combined_loader.iterables
|
||||
test_dl1 = dictionary_of_loaders["dl1"]
|
||||
|
||||
--------------
|
||||
|
|
|
@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))
|
||||
|
||||
|
||||
- Moved the `CombinedLoader` class from `lightning.pytorch.trainer.supporters` to `lightning.pytorch.combined_loader` ([#16819](https://github.com/Lightning-AI/lightning/pull/16819))
|
||||
|
||||
|
||||
- The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))
|
||||
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ from lightning.pytorch.trainer import call
|
|||
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
|
||||
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
|
||||
from lightning.pytorch.trainer.states import TrainerFn
|
||||
from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
|
||||
from lightning.pytorch.utilities.exceptions import SIGTERMException
|
||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||
|
||||
|
|
|
@ -17,7 +17,11 @@ from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union
|
|||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from lightning.fabric.utilities.data import has_len
|
||||
from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import (
|
||||
_Sequential,
|
||||
_shutdown_workers_and_reset_iterator,
|
||||
CombinedLoader,
|
||||
)
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from lightning.pytorch.trainer import call
|
|||
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
|
||||
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.trainer.supporters import CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
||||
from lightning.pytorch.utilities.data import has_len_all_ranks
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
|
||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||
|
|
|
@ -16,7 +16,7 @@ from lightning.pytorch.strategies import DDPSpawnStrategy
|
|||
from lightning.pytorch.trainer import call
|
||||
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.types import _PREDICT_OUTPUT
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler
|
|||
from lightning.pytorch.strategies import DDPSpawnStrategy
|
||||
from lightning.pytorch.trainer import call
|
||||
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
|
||||
from lightning.pytorch.trainer.supporters import CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
||||
from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy
|
|||
from lightning.fabric.utilities import LightningEnum # noqa: F401
|
||||
from lightning.fabric.utilities import move_data_to_device # noqa: F401
|
||||
from lightning.pytorch.accelerators.hpu import _HPU_AVAILABLE # noqa: F401
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader # noqa: F401
|
||||
from lightning.pytorch.utilities.enums import GradClipAlgorithmType # noqa: F401
|
||||
from lightning.pytorch.utilities.grads import grad_norm # noqa: F401
|
||||
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCHVISION_AVAILABLE # noqa: F401
|
||||
|
|
|
@ -22,7 +22,7 @@ from lightning.pytorch import LightningDataModule, Trainer
|
|||
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
||||
from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher
|
||||
from lightning.pytorch.profilers import SimpleProfiler
|
||||
from lightning.pytorch.trainer.supporters import CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
|
|
@ -29,7 +29,7 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
|
|||
from lightning.pytorch.strategies import DDPSpawnStrategy
|
||||
from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
|
||||
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
|
||||
from lightning.pytorch.trainer.supporters import CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
||||
from lightning.pytorch.utilities.data import _update_dataloader
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
|
|
@ -34,7 +34,7 @@ from lightning.pytorch.demos.boring_classes import (
|
|||
)
|
||||
from lightning.pytorch.loggers import CSVLogger
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.trainer.supporters import CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import CombinedLoader
|
||||
from lightning.pytorch.utilities.data import has_len_all_ranks
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
|
||||
|
|
|
@ -25,7 +25,13 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler
|
|||
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
||||
from lightning.pytorch.trainer.supporters import _MaxSizeCycle, _MinSize, _Sequential, _supported_modes, CombinedLoader
|
||||
from lightning.pytorch.utilities.combined_loader import (
|
||||
_MaxSizeCycle,
|
||||
_MinSize,
|
||||
_Sequential,
|
||||
_supported_modes,
|
||||
CombinedLoader,
|
||||
)
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
||||
|
Loading…
Reference in New Issue