From b30a43f7831d3f826a16ee619874509479b8eff3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Mon, 20 Feb 2023 18:06:35 +0100
Subject: [PATCH] Move the `CombinedLoader` to an utility file (#16819)

---
 docs/source-pytorch/guides/data.rst           | 22 +++++++++----------
 src/lightning/pytorch/CHANGELOG.md            |  3 +++
 .../pytorch/loops/evaluation_loop.py          |  2 +-
 src/lightning/pytorch/loops/fetchers.py       |  6 ++++-
 src/lightning/pytorch/loops/fit_loop.py       |  2 +-
 .../pytorch/loops/prediction_loop.py          |  2 +-
 .../trainer/connectors/data_connector.py      |  2 +-
 src/lightning/pytorch/utilities/__init__.py   |  1 +
 .../combined_loader.py}                       |  0
 tests/tests_pytorch/loops/test_fetchers.py    |  2 +-
 .../trainer/connectors/test_data_connector.py |  2 +-
 .../tests_pytorch/trainer/test_dataloaders.py |  2 +-
 .../test_combined_loader.py}                  |  8 ++++++-
 13 files changed, 34 insertions(+), 20 deletions(-)
 rename src/lightning/pytorch/{trainer/supporters.py => utilities/combined_loader.py} (100%)
 rename tests/tests_pytorch/{trainer/test_supporters.py => utilities/test_combined_loader.py} (99%)

diff --git a/docs/source-pytorch/guides/data.rst b/docs/source-pytorch/guides/data.rst
index 9cf9b8b8d5..bfcb11dafc 100644
--- a/docs/source-pytorch/guides/data.rst
+++ b/docs/source-pytorch/guides/data.rst
@@ -49,8 +49,8 @@ There are a few ways to pass multiple Datasets to Lightning:
 2. In the training loop, you can pass multiple DataLoaders as a dict or list/tuple, and Lightning will
    automatically combine the batches from different DataLoaders.
 3. In the validation, test, or prediction, you have the option to return multiple DataLoaders as list/tuple, which Lightning will call sequentially
-   or combine the DataLoaders using :class:`~pytorch_lightning.trainer.supporters.CombinedLoader`, which Lightning will
-   automatically combine the batches from different DataLoaders.
+   or combine the DataLoaders using :class:`~pytorch_lightning.utilities.CombinedLoader`, which is what Lightning uses
+   under the hood.


 Using LightningDataModule
@@ -174,11 +174,11 @@ Furthermore, Lightning also supports nested lists and dicts (or a combination).
         batch_c = batch_c_d["c"]
         batch_d = batch_c_d["d"]

-Alternatively, you can also pass in a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` containing multiple DataLoaders.
+Alternatively, you can also pass in a :class:`~pytorch_lightning.utilities.CombinedLoader` containing multiple DataLoaders.

 .. testcode::

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader


     def train_dataloader(self):
@@ -222,18 +222,18 @@ Refer to the following for more details for the default sequential option:
             ...


-Evaluation DataLoaders are iterated over sequentially. If you want to iterate over them in parallel, PyTorch Lightning provides a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object which supports collections of DataLoaders such as list, tuple, or dictionary. The DataLoaders can be accessed using in the same way as the provided structure:
+Evaluation DataLoaders are iterated over sequentially. The above is equivalent to:

 .. testcode::

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader


     def val_dataloader(self):
         loader_a = DataLoader()
         loader_b = DataLoader()
         loaders = {"a": loader_a, "b": loader_b}
-        combined_loaders = CombinedLoader(loaders, mode="max_size_cycle")
+        combined_loaders = CombinedLoader(loaders, mode="sequential")
         return combined_loaders


@@ -279,12 +279,12 @@ In the case that you require access to the DataLoader or Dataset objects, DataLo
         # extract metadata, etc. from the dataset:
         ...

-If you are using a :class:`~pytorch_lightning.trainer.supporters.CombinedLoader` object which allows you to fetch batches from a collection of DataLoaders
+If you are using a :class:`~pytorch_lightning.utilities.CombinedLoader` object which allows you to fetch batches from a collection of DataLoaders
 simultaneously which supports collections of DataLoader such as list, tuple, or dictionary. The DataLoaders can be accessed using the same collection structure:

 .. code-block:: python

-    from pytorch_lightning.trainer.supporters import CombinedLoader
+    from pytorch_lightning.utilities import CombinedLoader

     test_dl1 = ...
     test_dl2 = ...
@@ -292,14 +292,14 @@ simultaneously which supports collections of DataLoader such as list, tuple, or

     # If you provided a list of DataLoaders:
     combined_loader = CombinedLoader([test_dl1, test_dl2])
-    list_of_loaders = combined_loader.loaders
+    list_of_loaders = combined_loader.iterables
     test_dl1 = list_of_loaders.loaders[0]

     # If you provided dictionary of DataLoaders:
     combined_loader = CombinedLoader({"dl1": test_dl1, "dl2": test_dl2})
-    dictionary_of_loaders = combined_loader.loaders
+    dictionary_of_loaders = combined_loader.iterables
     test_dl1 = dictionary_of_loaders["dl1"]


 --------------

diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 69efb4923c..502d92bc38 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))

+- Moved the `CombinedLoader` class from `lightning.pytorch.trainer.supporters` to `lightning.pytorch.utilities.combined_loader` ([#16819](https://github.com/Lightning-AI/lightning/pull/16819))
+
+
 - The top-level loops now own the data sources and combined dataloaders ([#16726](https://github.com/Lightning-AI/lightning/pull/16726))


diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py
index d821471e3c..630bff5e26 100644
--- a/src/lightning/pytorch/loops/evaluation_loop.py
+++ b/src/lightning/pytorch/loops/evaluation_loop.py
@@ -30,7 +30,7 @@ from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.exceptions import SIGTERMException
 from lightning.pytorch.utilities.model_helpers import is_overridden

diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py
index ee29e7b69c..5c31fcfc1a 100644
--- a/src/lightning/pytorch/loops/fetchers.py
+++ b/src/lightning/pytorch/loops/fetchers.py
@@ -17,7 +17,11 @@ from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union

 from torch.utils.data.dataloader import DataLoader

 from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import (
+    _Sequential,
+    _shutdown_workers_and_reset_iterator,
+    CombinedLoader,
+)
 from lightning.pytorch.utilities.exceptions import MisconfigurationException

diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 4c74d88f90..4f930e7df2 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -25,7 +25,7 @@ from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException, SIGTERMException
 from lightning.pytorch.utilities.model_helpers import is_overridden
diff --git a/src/lightning/pytorch/loops/prediction_loop.py b/src/lightning/pytorch/loops/prediction_loop.py
index 4454a79e1a..4ea3000135 100644
--- a/src/lightning/pytorch/loops/prediction_loop.py
+++ b/src/lightning/pytorch/loops/prediction_loop.py
@@ -16,7 +16,7 @@ from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import _Sequential, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import _Sequential, CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import _PREDICT_OUTPUT

diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py
index 691475560b..10f6eec96a 100644
--- a/src/lightning/pytorch/trainer/connectors/data_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -28,7 +28,7 @@ from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer import call
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
diff --git a/src/lightning/pytorch/utilities/__init__.py b/src/lightning/pytorch/utilities/__init__.py
index a6fd4f2230..98a16db95a 100644
--- a/src/lightning/pytorch/utilities/__init__.py
+++ b/src/lightning/pytorch/utilities/__init__.py
@@ -18,6 +18,7 @@ import numpy
 from lightning.fabric.utilities import LightningEnum  # noqa: F401
 from lightning.fabric.utilities import move_data_to_device  # noqa: F401
 from lightning.pytorch.accelerators.hpu import _HPU_AVAILABLE  # noqa: F401
+from lightning.pytorch.utilities.combined_loader import CombinedLoader  # noqa: F401
 from lightning.pytorch.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from lightning.pytorch.utilities.grads import grad_norm  # noqa: F401
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCHVISION_AVAILABLE  # noqa: F401
diff --git a/src/lightning/pytorch/trainer/supporters.py b/src/lightning/pytorch/utilities/combined_loader.py
similarity index 100%
rename from src/lightning/pytorch/trainer/supporters.py
rename to src/lightning/pytorch/utilities/combined_loader.py
diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py
index c7912be6bc..bd9e2e8d6e 100644
--- a/tests/tests_pytorch/loops/test_fetchers.py
+++ b/tests/tests_pytorch/loops/test_fetchers.py
@@ -22,7 +22,7 @@ from lightning.pytorch import LightningDataModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher
 from lightning.pytorch.profilers import SimpleProfiler
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import STEP_OUTPUT
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index e83fb644bc..de8fb78218 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -29,7 +29,7 @@ from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.strategies import DDPSpawnStrategy
 from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index dfeda2ef4f..f945168e70 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -34,7 +34,7 @@ from lightning.pytorch.demos.boring_classes import (
 )
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.trainer.states import RunningStage
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
diff --git a/tests/tests_pytorch/trainer/test_supporters.py b/tests/tests_pytorch/utilities/test_combined_loader.py
similarity index 99%
rename from tests/tests_pytorch/trainer/test_supporters.py
rename to tests/tests_pytorch/utilities/test_combined_loader.py
index 1d58a09fdb..bb782ddadc 100644
--- a/tests/tests_pytorch/trainer/test_supporters.py
+++ b/tests/tests_pytorch/utilities/test_combined_loader.py
@@ -25,7 +25,13 @@ from torch.utils.data.sampler import RandomSampler, SequentialSampler

 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
-from lightning.pytorch.trainer.supporters import _MaxSizeCycle, _MinSize, _Sequential, _supported_modes, CombinedLoader
+from lightning.pytorch.utilities.combined_loader import (
+    _MaxSizeCycle,
+    _MinSize,
+    _Sequential,
+    _supported_modes,
+    CombinedLoader,
+)
 from tests_pytorch.helpers.runif import RunIf
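
Not part of the patch: a minimal usage sketch of the relocated class, based on the docs and `utilities/__init__.py` changes above. The toy dataloaders, their sizes, and the direct iteration outside of a `Trainer` are illustrative assumptions, not something this diff adds.

import torch
from torch.utils.data import DataLoader, TensorDataset

# New import location after this change; `lightning.pytorch.utilities.combined_loader`
# also works, since `utilities/__init__.py` re-exports the class (see hunk above).
from lightning.pytorch.utilities import CombinedLoader

# Two hypothetical dataloaders of different lengths, mirroring the docs' {"a": ..., "b": ...} example.
iterables = {
    "a": DataLoader(TensorDataset(torch.arange(6.0)), batch_size=2),
    "b": DataLoader(TensorDataset(torch.arange(4.0)), batch_size=2),
}

combined_loader = CombinedLoader(iterables, mode="max_size_cycle")

# The input collection is exposed as `iterables` (renamed from `loaders` in #16743).
assert combined_loader.iterables["a"] is iterables["a"]

for batch in combined_loader:
    # With "max_size_cycle", each step should yield the same structure as the input,
    # e.g. {"a": <batch from loader a>, "b": <batch from loader b>}.
    print(batch)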