Avoid enforcing `shuffle=False` for eval dataloaders (#11575)

Rohit Gupta 2022-02-03 15:05:31 +05:30 committed by GitHub
parent 9ebd7df22a
commit 7948ed703d
5 changed files with 40 additions and 10 deletions


@ -497,6 +497,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `SimpleProfiler` summary ([#11414](https://github.com/PyTorchLightning/pytorch-lightning/pull/11414))
- Avoid enforcing `shuffle=False` for eval dataloaders ([#11575](https://github.com/PyTorchLightning/pytorch-lightning/pull/11575))
- Disabled sampler replacement when using `IterableDataset` ([#11507](https://github.com/PyTorchLightning/pytorch-lightning/pull/11507))
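As a rough illustration of what the `shuffle=False` entry means in practice (a sketch assuming a DDP run, not code from this commit): an eval dataloader the user builds with `shuffle=True` now keeps shuffling after Lightning swaps in a `DistributedSampler`, instead of having `shuffle=False` enforced.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 32))
# The user explicitly asks for a shuffled validation dataloader.
val_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Conceptually, the replacement sampler now mirrors that choice:
new_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)  # after this change
# Previously the replacement was always created with shuffle=False for eval dataloaders.
```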


@ -19,7 +19,6 @@ from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
from weakref import proxy
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
@ -35,6 +34,7 @@ from pytorch_lightning.utilities.auto_restart import (
)
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_dataloader_init_method,
_update_dataloader,
has_iterable_dataset,
@ -329,7 +329,10 @@ class DataConnector:
and not has_iterable_dataset(dataloader)
)
def _prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
# TODO: shuffle here is kept for BC. Remove it once data_loading.py is removed (#11248)
def _prepare_dataloader(
self, dataloader: Any, shuffle: Optional[bool] = None, mode: Optional[RunningStage] = None
) -> Any:
"""This function handles to following functionalities:
- Injecting a `DistributedDataSampler` into the `DataLoader` if on a distributed environment
@ -360,6 +363,11 @@ class DataConnector:
or mode == RunningStage.PREDICTING # to track indices for the predictions
or self.trainer._accelerator_connector.use_ipu # IPUs use a custom `DataLoader`
):
if shuffle is None:
# for training, set to True always
# for evaluation, decide based on existing sampler
shuffle = True if mode == RunningStage.TRAINING else _is_dataloader_shuffled(dataloader)
sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
dataloader = _update_dataloader(dataloader, sampler, mode=mode)
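For reference, the resolution logic added in the hunk above can be read as the following standalone sketch (the helper name `resolve_eval_shuffle` is illustrative, not part of the diff): an explicit `shuffle` argument still wins for backward compatibility, training dataloaders always shuffle, and eval dataloaders inherit whatever sampler the user configured instead of being forced to `shuffle=False`.

```python
from typing import Optional

from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


def resolve_eval_shuffle(dataloader: DataLoader, training: bool, shuffle: Optional[bool] = None) -> bool:
    if shuffle is not None:  # explicit value kept for backward compatibility
        return shuffle
    if training:  # training dataloaders are always reshuffled
        return True
    # evaluation: respect the existing sampler instead of enforcing shuffle=False
    return (
        hasattr(dataloader, "sampler")
        and not isinstance(dataloader.sampler, SequentialSampler)
        and not isinstance(dataloader.dataset, IterableDataset)
    )
```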
@ -435,7 +443,7 @@ class DataConnector:
)
# add samplers
dataloaders = [self._prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]
dataloaders = [self._prepare_dataloader(dl, mode=mode) for dl in dataloaders if dl is not None]
# add worker_init_fn for correct seeding in worker processes
apply_to_collection(
@ -531,11 +539,7 @@ class DataConnector:
@staticmethod
def _check_eval_shuffling(dataloader, mode):
if (
hasattr(dataloader, "sampler")
and not isinstance(dataloader.sampler, SequentialSampler)
and not isinstance(dataloader.dataset, IterableDataset)
):
if _is_dataloader_shuffled(dataloader):
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
" it is strongly recommended that you turn this off for val/test/predict dataloaders.",


@ -1768,7 +1768,6 @@ class Trainer(
self.train_dataloader,
(DataLoader, CombinedLoader),
self._data_connector._prepare_dataloader,
shuffle=True,
mode=RunningStage.TRAINING,
)
loaders = (


@ -20,7 +20,7 @@ from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union
import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler, SequentialSampler
import pytorch_lightning as pl
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
@ -374,3 +374,11 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) ->
else:
raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
return dl_kwargs
def _is_dataloader_shuffled(dataloader: DataLoader):
return (
hasattr(dataloader, "sampler")
and not isinstance(dataloader.sampler, SequentialSampler)
and not isinstance(dataloader.dataset, IterableDataset)
)
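A quick usage check of the helper added above, assuming it is imported from the same `pytorch_lightning.utilities.data` module the data connector now pulls it from: a default `DataLoader` gets a `SequentialSampler` and reports not shuffled, while `shuffle=True` installs a `RandomSampler` and reports shuffled.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.data import _is_dataloader_shuffled

dataset = TensorDataset(torch.randn(8, 2))
assert not _is_dataloader_shuffled(DataLoader(dataset))            # SequentialSampler
assert _is_dataloader_shuffled(DataLoader(dataset, shuffle=True))  # RandomSampler
```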


@ -19,6 +19,7 @@ from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.boring_model import RandomDataset
class NoDataLoaderModel(BoringModel):
@ -66,3 +67,18 @@ def test_dataloader_source_request_from_module():
module.foo.assert_not_called()
assert isinstance(source.dataloader(), DataLoader)
module.foo.assert_called_once()
@pytest.mark.parametrize("shuffle", [True, False])
def test_eval_shuffle_with_distributed_sampler_replacement(shuffle):
"""Test that shuffle is not changed if set to True."""
class CustomModel(BoringModel):
def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), shuffle=shuffle)
trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp")
model = CustomModel()
trainer._data_connector.attach_data(model)
trainer.reset_val_dataloader(model)
assert trainer.val_dataloaders[0].sampler.shuffle == shuffle
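The final assertion relies on `DistributedSampler` exposing the `shuffle` flag it was constructed with. A minimal standalone check of that assumption (the explicit `num_replicas`/`rank` values are only there to avoid initializing a process group):

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(64, 32))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
assert sampler.shuffle is True  # the flag is stored on the sampler as-is
```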