From 400201712fc13079d8d552fa121daa7a2a167e56 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Fri, 4 Feb 2022 00:12:13 +0530
Subject: [PATCH] added warning for distributedsampler in case of evaluation
 (#11479)

---
 CHANGELOG.md                                   |  3 +++
 docs/source/common/test_set.rst                | 11 +++++++++--
 .../trainer/connectors/data_connector.py       | 17 +++++++++++++++--
 .../trainer/connectors/test_data_connector.py  | 18 ++++++++++++++++++
 4 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index afb7c55dc9..8455f242b5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for DDP when using a `CombinedLoader` for the training data ([#11648](https://github.com/PyTorchLightning/pytorch-lightning/pull/11648))
 
+- Added a warning when using `DistributedSampler` during validation/testing ([#11479](https://github.com/PyTorchLightning/pytorch-lightning/pull/11479))
+
+
 ### Changed
 
 - Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472))
diff --git a/docs/source/common/test_set.rst b/docs/source/common/test_set.rst
index 72080f70c1..fa14ddfec0 100644
--- a/docs/source/common/test_set.rst
+++ b/docs/source/common/test_set.rst
@@ -50,8 +50,8 @@ To run the test set after training completes, use this method.
 
 .. warning::
 
-    It is recommended to test on single device since Distributed Training such as DDP internally
-    uses :class:`~torch.utils.data.distributed.DistributedSampler` which replicates some samples to
+    It is recommended to test with ``Trainer(devices=1)`` since distributed strategies such as DDP
+    use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
     make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure
     benchmarking for research papers is done the right way.
 
@@ -144,5 +144,12 @@ Apart from this ``.validate`` has same API as ``.test``, but would rely respecti
 ``.validate`` method uses the same validation logic being used under validation happening within
 :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` call.
 
+.. warning::
+
+    When using ``trainer.validate()``, it is recommended to use ``Trainer(devices=1)`` since distributed strategies such as DDP
+    use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
+    make sure all devices have same batch size in case of uneven inputs. This is helpful to make sure
+    benchmarking for research papers is done the right way.
+
 .. automethod:: pytorch_lightning.trainer.Trainer.validate
     :noindex:
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 8c12f94583..c83060244c 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -24,7 +24,7 @@ from torch.utils.data.distributed import DistributedSampler
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
-from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -386,7 +386,7 @@ class DataConnector:
                 " distributed training. Either remove the sampler from your DataLoader or set"
                 " `replace_sampler_ddp=False` if you want to use your custom sampler."
             )
-            return self._get_distributed_sampler(
+            sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
@@ -394,6 +394,19 @@ class DataConnector:
                 **self.trainer.distributed_sampler_kwargs,
             )
 
+            # update docs too once this is resolved
+            trainer_fn = self.trainer.state.fn
+            if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
+                rank_zero_warn(
+                    f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
+                    " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                    " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
+                    " some samples to make sure all devices have same batch size in case of uneven inputs.",
+                    category=PossibleUserWarning,
+                )
+
+            return sampler
+
         return dataloader.sampler
 
     @staticmethod
diff --git a/tests/trainer/connectors/test_data_connector.py b/tests/trainer/connectors/test_data_connector.py
index 429ac9237a..a2cb02e423 100644
--- a/tests/trainer/connectors/test_data_connector.py
+++ b/tests/trainer/connectors/test_data_connector.py
@@ -18,6 +18,8 @@ from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
+from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.boring_model import RandomDataset
 
@@ -69,6 +71,22 @@ def test_dataloader_source_request_from_module():
     module.foo.assert_called_once()
 
 
+def test_eval_distributed_sampler_warning(tmpdir):
+    """Test that a warning is raised when `DistributedSampler` is used with evaluation."""
+
+    model = BoringModel()
+    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
+    trainer._data_connector.attach_data(model)
+
+    trainer.state.fn = TrainerFn.VALIDATING
+    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+        trainer.reset_val_dataloader(model)
+
+    trainer.state.fn = TrainerFn.TESTING
+    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+        trainer.reset_test_dataloader(model)
+
+
 @pytest.mark.parametrize("shuffle", [True, False])
 def test_eval_shuffle_with_distributed_sampler_replacement(shuffle):
     """Test that shuffle is not changed if set to True."""