Added warning for `DistributedSampler` in case of evaluation (#11479)

Rohit Gupta 2022-02-04 00:12:13 +05:30 committed by GitHub
parent a34930b772
commit 400201712f
4 changed files with 45 additions and 4 deletions


@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for DDP when using a `CombinedLoader` for the training data ([#11648](https://github.com/PyTorchLightning/pytorch-lightning/pull/11648))
- Added a warning when using `DistributedSampler` during validation/testing ([#11479](https://github.com/PyTorchLightning/pytorch-lightning/pull/11479))
### Changed
- Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472))


@@ -50,8 +50,8 @@ To run the test set after training completes, use this method.
.. warning::
It is recommended to test on single device since Distributed Training such as DDP internally
uses :class:`~torch.utils.data.distributed.DistributedSampler` which replicates some samples to
It is recommended to test with ``Trainer(devices=1)`` since distributed strategies such as DDP
use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
make sure all devices have the same batch size in case of uneven inputs. This helps ensure that
benchmarking for research papers is done the right way.
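The replication mentioned in this warning can be reproduced with a stand-alone sampler. A minimal sketch, not part of this commit, using an illustrative 5-sample dataset split across 2 replicas:

# Illustrative only: DistributedSampler pads uneven datasets by repeating indices,
# so some samples would be evaluated more than once on multiple devices.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(5))  # 5 samples shared by 2 "devices" -> uneven

for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(f"rank {rank}: {list(sampler)}")

# rank 0: [0, 2, 4]
# rank 1: [1, 3, 0]   <- index 0 is repeated so both ranks get 3 samples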
@@ -144,5 +144,12 @@ Apart from this, ``.validate`` has the same API as ``.test``, but would rely respectively
The ``.validate`` method uses the same validation logic that runs during validation inside a
:meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` call.
.. warning::
When using ``trainer.validate()``, it is recommended to use ``Trainer(devices=1)`` since distributed strategies such as DDP
use :class:`~torch.utils.data.distributed.DistributedSampler` internally, which replicates some samples to
make sure all devices have the same batch size in case of uneven inputs. This helps ensure that
benchmarking for research papers is done the right way.
.. automethod:: pytorch_lightning.trainer.Trainer.validate
:noindex:
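For reference, a minimal sketch of the single-device evaluation recommended above; the ``TinyModule`` and the random dataset below are illustrative assumptions, not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    """Toy module used only to illustrate single-device evaluation."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        (x,) = batch
        self.log("val_loss", self.layer(x).sum())

    def test_step(self, batch, batch_idx):
        (x,) = batch
        self.log("test_loss", self.layer(x).sum())


loader = DataLoader(TensorDataset(torch.randn(5, 4)), batch_size=2)  # 5 samples: uneven for >1 device

# With a single device no DistributedSampler is injected, so each of the 5 samples is evaluated exactly once.
trainer = pl.Trainer(accelerator="cpu", devices=1)
trainer.validate(TinyModule(), dataloaders=loader)
trainer.test(TinyModule(), dataloaders=loader)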


@@ -24,7 +24,7 @@ from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -386,7 +386,7 @@ class DataConnector:
" distributed training. Either remove the sampler from your DataLoader or set"
" `replace_sampler_ddp=False` if you want to use your custom sampler."
)
return self._get_distributed_sampler(
sampler = self._get_distributed_sampler(
dataloader,
shuffle,
mode=mode,
@@ -394,6 +394,19 @@
**self.trainer.distributed_sampler_kwargs,
)
# update docs too once this is resolved
trainer_fn = self.trainer.state.fn
if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
rank_zero_warn(
f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
" it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
" exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
" some samples to make sure all devices have same batch size in case of uneven inputs.",
category=PossibleUserWarning,
)
return sampler
return dataloader.sampler
@staticmethod
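For completeness, a hedged sketch (not part of this commit) of one way to evaluate across multiple devices without the replication that this warning describes: disable automatic sampler replacement and attach Lightning's `UnrepeatedDistributedSampler`, which shards the dataset without padding. This assumes it accepts the same arguments as `DistributedSampler` (it subclasses it); uneven shards can stall collective ops, so it is suited to inference-style evaluation only.

import torch.distributed as dist
from torch.utils.data import DataLoader
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler


def build_eval_dataloader(dataset, batch_size=32):
    # Each rank receives a distinct, possibly smaller shard; no sample is repeated.
    sampler = UnrepeatedDistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=False,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# Pair with Trainer(replace_sampler_ddp=False, ...) so Lightning keeps this sampler untouched.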


@@ -18,6 +18,8 @@ from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.data_connector import _DataLoaderSource
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.boring_model import RandomDataset
@@ -69,6 +71,22 @@ def test_dataloader_source_request_from_module():
module.foo.assert_called_once()
def test_eval_distributed_sampler_warning(tmpdir):
"""Test that a warning is raised when `DistributedSampler` is used with evaluation."""
model = BoringModel()
trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
trainer._data_connector.attach_data(model)
trainer.state.fn = TrainerFn.VALIDATING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_val_dataloader(model)
trainer.state.fn = TrainerFn.TESTING
with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
trainer.reset_test_dataloader(model)
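A small usage note, not part of this commit: users who knowingly evaluate on multiple devices and accept the duplicated samples can silence this `PossibleUserWarning` with a standard warnings filter.

import warnings
from pytorch_lightning.utilities.warnings import PossibleUserWarning

# Suppress the evaluation DistributedSampler warning process-wide (use deliberately).
warnings.filterwarnings("ignore", category=PossibleUserWarning)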
@pytest.mark.parametrize("shuffle", [True, False])
def test_eval_shuffle_with_distributed_sampler_replacement(shuffle):
"""Test that shuffle is not changed if set to True."""