diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9050d7576a..d9c2e189c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -497,6 +497,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `SimpleProfiler` summary ([#11414](https://github.com/PyTorchLightning/pytorch-lightning/pull/11414))
 
+- Avoid enforcing `shuffle=False` for eval dataloaders ([#11575](https://github.com/PyTorchLightning/pytorch-lightning/pull/11575))
+
+
 - Disbled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 688f6c9a4b..8c12f94583 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -19,7 +19,6 @@ from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
-from torch.utils.data.dataset import IterableDataset
 from torch.utils.data.distributed import DistributedSampler
 
 import pytorch_lightning as pl
@@ -35,6 +34,7 @@ from pytorch_lightning.utilities.auto_restart import (
 )
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
+    _is_dataloader_shuffled,
     _replace_dataloader_init_method,
     _update_dataloader,
     has_iterable_dataset,
@@ -329,7 +329,10 @@ class DataConnector:
             and not has_iterable_dataset(dataloader)
         )
 
-    def _prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
+    # TODO: shuffle here is kept for BC. Remove it once data_loading.py is removed (#11248)
+    def _prepare_dataloader(
+        self, dataloader: Any, shuffle: Optional[bool] = None, mode: Optional[RunningStage] = None
+    ) -> Any:
         """This function handles to following functionalities:
 
         - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
@@ -360,6 +363,11 @@ class DataConnector:
             or mode == RunningStage.PREDICTING  # to track indices for the predictions
             or self.trainer._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
         ):
+            if shuffle is None:
+                # for training, set to True always
+                # for evaluation, decide based on existing sampler
+                shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)
+
             sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
             dataloader = _update_dataloader(dataloader, sampler, mode=mode)
 
@@ -435,7 +443,7 @@ class DataConnector:
         )
 
         # add samplers
-        dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]
+        dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
 
         # add worker_init_fn for correct seeding in worker processes
         apply_to_collection(
@@ -531,11 +539,7 @@ class DataConnector:
 
     @staticmethod
     def _check_eval_shuffling(dataloader, mode):
-        if (
-            hasattr(dataloader, "sampler")
-            and not isinstance(dataloader.sampler, SequentialSampler)
-            and not isinstance(dataloader.dataset, IterableDataset)
-        ):
+        if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
                 f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                 " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a41bc9f6ed..4a49723bf0 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1768,7 +1768,6 @@ class Trainer(
             self.train_dataloader,
             (DataLoader, CombinedLoader),
             self._data_connector._prepare_dataloader,
-            shuffle=True,
             mode=RunningStage.TRAINING,
         )
         loaders = (
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index b365ab99f3..7e698da8e4 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -20,7 +20,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
 
 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -374,3 +374,11 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) ->
     else:
         raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
     return dl_kwargs
+
+
+def _is_dataloader_shuffled(dataloader: DataLoader):
+    return (
+        hasattr(dataloader, "sampler")
+        and not isinstance(dataloader.sampler, SequentialSampler)
+        and not isinstance(dataloader.dataset, IterableDataset)
+    )
diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py
index 4d614ecc25..429ac9237a 100644
--- a/tests/trainer/connectors/test_data_connector.py
+++ b/tests/trainer/connectors/test_data_connector.py
@@ -19,6 +19,7 @@ from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
 from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers.boring_model import RandomDataset
 
 
 class NoDataLoaderModel(BoringModel):
@@ -66,3 +67,18 @@ def test_dataloader_source_request_from_module():
     module.foo.assert_not_called()
     assert isinstance(source.dataloader(), DataLoader)
     module.foo.assert_called_once()
+
+
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_eval_shuffle_with_distributed_sampler_replacement(shuffle):
+    """Test that a user-set `shuffle` flag is preserved when the distributed sampler is injected."""
+
+    class CustomModel(BoringModel):
+        def val_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), shuffle=shuffle)
+
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp")
+    model = CustomModel()
+    trainer._data_connector.attach_data(model)
+    trainer.reset_val_dataloader(model)
+    assert trainer.val_dataloaders[0].sampler.shuffle == shuffle
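
Reviewer note: a minimal standalone sketch (not part of the patch) of the heuristic behind the new `_is_dataloader_shuffled` helper. A plain `DataLoader` constructs a `SequentialSampler` when `shuffle=False` and a `RandomSampler` when `shuffle=True`, so the sampler type recovers the user's shuffle choice for map-style datasets. The helper body mirrors the hunk in pytorch_lightning/utilities/data.py; the `TensorDataset` data is illustrative only.

import torch
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler, TensorDataset


def _is_dataloader_shuffled(dataloader: DataLoader) -> bool:
    # Same logic as the helper added above: a non-sequential sampler on a
    # map-style dataset implies the loader was built with shuffle=True.
    return (
        hasattr(dataloader, "sampler")
        and not isinstance(dataloader.sampler, SequentialSampler)
        and not isinstance(dataloader.dataset, IterableDataset)
    )


dataset = TensorDataset(torch.randn(8, 2))
assert _is_dataloader_shuffled(DataLoader(dataset, shuffle=True))  # RandomSampler -> shuffled
assert not _is_dataloader_shuffled(DataLoader(dataset, shuffle=False))  # SequentialSampler -> not shuffled

Note that the heuristic is conservative: any custom, non-`SequentialSampler` sampler is treated as shuffled, which is also what triggers the eval-shuffle warning in `_check_eval_shuffling`.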