Fix shuffle for distributed sampler (#2789)

* Fix shuffle for distributed sampler

* add test

* test

* chlog

* update test

* update test

* update test

* assertions via callback

* define callback outside for pickling

* skip ddp test on windows

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Rohit Gupta 2020-08-02 08:52:57 +05:30 committed by GitHub
parent 38fce2ea68
commit 8baec1a191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 3 deletions

View File

@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
## [0.8.5] - 2020-07-09
### Added

View File

@ -163,7 +163,7 @@ class TrainerDataLoadingMixin(ABC):
' `replace_sampler_ddp`=False if you want to use your custom sampler.')
# replace with distributed sampler
sampler = self._get_distributed_sampler(dataloader)
sampler = self._get_distributed_sampler(dataloader, train)
dataloader = self.replace_sampler(dataloader, sampler)
return dataloader
@ -179,7 +179,7 @@ class TrainerDataLoadingMixin(ABC):
dataloader = type(dataloader)(**dl_args)
return dataloader
def _get_distributed_sampler(self, dataloader):
def _get_distributed_sampler(self, dataloader, train):
if self.use_tpu:
kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
elif self.use_horovod:
@ -193,6 +193,8 @@ class TrainerDataLoadingMixin(ABC):
}
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
kwargs['shuffle'] = train
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

View File

@ -7,9 +7,10 @@ import torch
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.distributed import DistributedSampler
import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@ -640,6 +641,42 @@ def test_dataloader_reinit_for_subclass(tmpdir):
CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
class DistribSamplerCallback(Callback):
    """Callback that asserts the trainer's dataloaders use a ``DistributedSampler``.

    The train sampler must have ``shuffle`` enabled, while the first validation
    and first test samplers must have it disabled.
    """

    @staticmethod
    def _check_sampler(sampler, expect_shuffle):
        # Shared check: the sampler type first, then its shuffle flag.
        assert isinstance(sampler, DistributedSampler)
        assert bool(sampler.shuffle) == expect_shuffle

    def on_train_start(self, trainer, pl_module):
        self._check_sampler(trainer.train_dataloader.sampler, expect_shuffle=True)

    def on_validation_start(self, trainer, pl_module):
        # Only the first validation dataloader is inspected.
        self._check_sampler(trainer.val_dataloaders[0].sampler, expect_shuffle=False)

    def on_test_start(self, trainer, pl_module):
        # Only the first test dataloader is inspected.
        self._check_sampler(trainer.test_dataloaders[0].sampler, expect_shuffle=False)
@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_distributed_sampler(tmpdir):
    """Test ``DistributedSampler`` and its arguments for the DDP backend.

    Runs a one-step fit followed by a test pass with ``ddp_spawn`` on two GPUs;
    the actual assertions are made by ``DistribSamplerCallback``, which is
    passed to the trainer.
    """
    template_model = EvalModelTemplate()
    trainer_options = dict(
        gpus=[0, 1],
        num_nodes=1,
        distributed_backend='ddp_spawn',
        default_root_dir=tmpdir,
        max_steps=1,
        callbacks=[DistribSamplerCallback()],
    )
    ddp_trainer = Trainer(**trainer_options)
    ddp_trainer.fit(template_model)
    ddp_trainer.test(ckpt_path=None)
@pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
def test_batch_size_smaller_than_num_gpus(tmpdir):
# we need at least 3 gpus for this test