Avoid enforcing `shuffle=False` for eval dataloaders (#11575)
commit 7948ed703d (parent 9ebd7df22a)
CHANGELOG.md

@@ -497,6 +497,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `SimpleProfiler` summary ([#11414](https://github.com/PyTorchLightning/pytorch-lightning/pull/11414))

+- Avoid enforcing `shuffle=False` for eval dataloaders ([#11575](https://github.com/PyTorchLightning/pytorch-lightning/pull/11575))

 - Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
pytorch_lightning/trainer/connectors/data_connector.py

@@ -19,7 +19,6 @@ from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
-from torch.utils.data.dataset import IterableDataset
 from torch.utils.data.distributed import DistributedSampler

 import pytorch_lightning as pl
@@ -35,6 +34,7 @@ from pytorch_lightning.utilities.auto_restart import (
 )
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
+    _is_dataloader_shuffled,
     _replace_dataloader_init_method,
     _update_dataloader,
     has_iterable_dataset,
@@ -329,7 +329,10 @@ class DataConnector:
             and not has_iterable_dataset(dataloader)
         )

-    def _prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
+    # TODO: shuffle here is kept for BC. Remove it once data_loading.py is removed (#11248)
+    def _prepare_dataloader(
+        self, dataloader: Any, shuffle: Optional[bool] = None, mode: Optional[RunningStage] = None
+    ) -> Any:
         """This function handles the following functionalities:

         - Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
@@ -360,6 +363,11 @@ class DataConnector:
             or mode == RunningStage.PREDICTING  # to track indices for the predictions
             or self.trainer._accelerator_connector.use_ipu  # IPUs use a custom `DataLoader`
         ):
+            if shuffle is None:
+                # for training, set to True always
+                # for evaluation, decide based on existing sampler
+                shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)
+
             sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
             dataloader = _update_dataloader(dataloader, sampler, mode=mode)
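Taken together, the new default means: when `shuffle` is not passed, training dataloaders shuffle as before, while eval dataloaders keep whatever the user configured instead of being forced to `shuffle=False`. A minimal standalone sketch of that rule (simplified: a plain string stands in for `RunningStage`, and `resolve_shuffle` is a hypothetical name, not Lightning API):

    from typing import Optional

    from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


    def resolve_shuffle(dataloader: DataLoader, mode: str, shuffle: Optional[bool] = None) -> bool:
        # An explicit flag wins (kept for backward compatibility, as in the diff above).
        if shuffle is not None:
            return shuffle
        if mode == "train":
            return True  # training always shuffles
        # Evaluation: infer from the sampler the user's DataLoader already carries.
        return (
            hasattr(dataloader, "sampler")
            and not isinstance(dataloader.sampler, SequentialSampler)
            and not isinstance(dataloader.dataset, IterableDataset)
        )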
@@ -435,7 +443,7 @@ class DataConnector:
         )

         # add samplers
-        dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]
+        dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]

         # add worker_init_fn for correct seeding in worker processes
         apply_to_collection(
@@ -531,11 +539,7 @@ class DataConnector:

     @staticmethod
     def _check_eval_shuffling(dataloader, mode):
-        if (
-            hasattr(dataloader, "sampler")
-            and not isinstance(dataloader.sampler, SequentialSampler)
-            and not isinstance(dataloader.dataset, IterableDataset)
-        ):
+        if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
                 f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
                 " it is strongly recommended that you turn this off for val/test/predict dataloaders.",
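The shuffle check leans on a PyTorch default worth spelling out: a map-style `DataLoader` gets a `SequentialSampler` unless `shuffle=True` is passed, in which case it gets a `RandomSampler`. A quick demonstration:

    import torch
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

    dataset = TensorDataset(torch.randn(8, 2))

    # Default (shuffle=False) -> SequentialSampler; shuffle=True -> RandomSampler.
    assert isinstance(DataLoader(dataset).sampler, SequentialSampler)
    assert isinstance(DataLoader(dataset, shuffle=True).sampler, RandomSampler)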
pytorch_lightning/trainer/trainer.py

@@ -1768,7 +1768,6 @@ class Trainer(
             self.train_dataloader,
             (DataLoader, CombinedLoader),
             self._data_connector._prepare_dataloader,
-            shuffle=True,
             mode=RunningStage.TRAINING,
         )
         loaders = (
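This call site can drop `shuffle=True` because `_prepare_dataloader` now derives it from `mode=RunningStage.TRAINING`. For context, `apply_to_collection` maps a function over every element of a (possibly nested) collection that matches the given type and forwards extra kwargs, which is how `mode` reaches `_prepare_dataloader`; a small illustration:

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    # Doubles every int in a nested collection; non-matching values pass through.
    doubled = apply_to_collection({"a": 1, "b": [2, 3]}, int, lambda x: x * 2)
    assert doubled == {"a": 2, "b": [4, 6]}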
pytorch_lightning/utilities/data.py

@@ -20,7 +20,7 @@ from itertools import chain
 from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union

 import torch
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler

 import pytorch_lightning as pl
 from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@@ -374,3 +374,11 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) ->
     else:
         raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
     return dl_kwargs
+
+
+def _is_dataloader_shuffled(dataloader: DataLoader):
+    return (
+        hasattr(dataloader, "sampler")
+        and not isinstance(dataloader.sampler, SequentialSampler)
+        and not isinstance(dataloader.dataset, IterableDataset)
+    )
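With the commit applied, the new helper (a private utility, so subject to change) behaves as follows, given the sampler defaults shown earlier:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning.utilities.data import _is_dataloader_shuffled

    dataset = TensorDataset(torch.randn(64, 32))

    assert _is_dataloader_shuffled(DataLoader(dataset, shuffle=True))  # RandomSampler
    assert not _is_dataloader_shuffled(DataLoader(dataset))            # SequentialSampler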
tests/trainer/connectors/test_data_connector.py

@@ -19,6 +19,7 @@ from torch.utils.data import DataLoader
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
 from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers.boring_model import RandomDataset


 class NoDataLoaderModel(BoringModel):
@@ -66,3 +67,18 @@ def test_dataloader_source_request_from_module():
     module.foo.assert_not_called()
     assert isinstance(source.dataloader(), DataLoader)
     module.foo.assert_called_once()
+
+
+@pytest.mark.parametrize("shuffle", [True, False])
+def test_eval_shuffle_with_distributed_sampler_replacement(shuffle):
+    """Test that shuffle is not changed if set to True."""
+
+    class CustomModel(BoringModel):
+        def val_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), shuffle=shuffle)
+
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp")
+    model = CustomModel()
+    trainer._data_connector.attach_data(model)
+    trainer.reset_val_dataloader(model)
+    assert trainer.val_dataloaders[0].sampler.shuffle == shuffle
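The final assertion works because `DistributedSampler` stores the `shuffle` flag it was constructed with, so the user's choice survives Lightning's sampler replacement:

    import torch
    from torch.utils.data import DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.randn(8, 2))

    # No process group needed when num_replicas/rank are given explicitly.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    assert sampler.shuffle is True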