Fix shuffle for distributed sampler (#2789)
* Fix shuffle for distributed sampler
* add test
* test
* chlog
* update test
* update test
* update test
* assertions via callback
* define callback outside for pickling
* skip ddp test on windows

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 38fce2ea68
commit 8baec1a191
```diff
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))
 
+- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
```
```diff
@@ -163,7 +163,7 @@ class TrainerDataLoadingMixin(ABC):
                     ' `replace_sampler_ddp`=False if you want to use your custom sampler.')
 
             # replace with distributed sampler
-            sampler = self._get_distributed_sampler(dataloader)
+            sampler = self._get_distributed_sampler(dataloader, train)
             dataloader = self.replace_sampler(dataloader, sampler)
 
         return dataloader
```
```diff
@@ -179,7 +179,7 @@ class TrainerDataLoadingMixin(ABC):
         dataloader = type(dataloader)(**dl_args)
         return dataloader
 
-    def _get_distributed_sampler(self, dataloader):
+    def _get_distributed_sampler(self, dataloader, train):
         if self.use_tpu:
             kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
         elif self.use_horovod:
```
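For context, the `type(dataloader)(**dl_args)` call above is the `replace_sampler` pattern: rebuild the loader from its constructor arguments with a new sampler swapped in, which is why the shuffle behaviour has to live on the sampler itself. A minimal sketch of that idea, assuming a hypothetical `rebuild_with_sampler` helper and a hand-picked argument list (not the library's exact implementation):

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def rebuild_with_sampler(dataloader: DataLoader, sampler) -> DataLoader:
    # Collect the arguments the loader can be reconstructed from, then swap in
    # the new sampler. `shuffle` is intentionally absent: torch's DataLoader
    # rejects `shuffle=True` together with an explicit `sampler`.
    dl_args = {
        'dataset': dataloader.dataset,
        'batch_size': dataloader.batch_size,
        'num_workers': dataloader.num_workers,
        'collate_fn': dataloader.collate_fn,
        'pin_memory': dataloader.pin_memory,
        'drop_last': dataloader.drop_last,
        'sampler': sampler,
    }
    return type(dataloader)(**dl_args)


if __name__ == '__main__':
    loader = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=4, shuffle=True)
    new_loader = rebuild_with_sampler(loader, SequentialSampler(loader.dataset))
    assert isinstance(new_loader.sampler, SequentialSampler)
```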
```diff
@@ -193,6 +193,8 @@ class TrainerDataLoadingMixin(ABC):
             }
             assert self.distributed_backend is not None
             kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
+
+        kwargs['shuffle'] = train
         sampler = DistributedSampler(dataloader.dataset, **kwargs)
         return sampler
 
```
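The added `kwargs['shuffle'] = train` line is the actual fix: torch's `DistributedSampler` shuffles by default, so before this change no shuffle argument was passed and every replaced loader, validation and test included, was shuffled. Tying `shuffle` to the `train` flag keeps shuffling for training only. A standalone sketch of the same idea; the helper name and the explicit `num_replicas`/`rank` arguments are illustrative, not the trainer's internals:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_distributed_sampler(dataloader: DataLoader, train: bool,
                             num_replicas: int, rank: int) -> DistributedSampler:
    # Shuffle only the training data; validation/test shards keep their order.
    return DistributedSampler(dataloader.dataset,
                              num_replicas=num_replicas,
                              rank=rank,
                              shuffle=train)


if __name__ == '__main__':
    loader = DataLoader(TensorDataset(torch.arange(10.0)), batch_size=2)
    val_sampler = make_distributed_sampler(loader, train=False, num_replicas=2, rank=0)
    assert not val_sampler.shuffle
    print(list(val_sampler))  # rank 0's shard, in order, e.g. [0, 2, 4, 6, 8]
```

In the trainer, `train` is the flag `auto_add_sampler` already receives, so only the training dataloader ends up with a shuffling sampler.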
```diff
@@ -7,9 +7,10 @@ import torch
 from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.distributed import DistributedSampler
 
 import tests.base.develop_pipelines as tpipes
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
```
```diff
@@ -640,6 +641,42 @@ def test_dataloader_reinit_for_subclass(tmpdir):
         CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
 
 
+class DistribSamplerCallback(Callback):
+
+    def on_train_start(self, trainer, pl_module):
+        train_sampler = trainer.train_dataloader.sampler
+        assert isinstance(train_sampler, DistributedSampler)
+        assert train_sampler.shuffle
+
+    def on_validation_start(self, trainer, pl_module):
+        val_sampler = trainer.val_dataloaders[0].sampler
+        assert isinstance(val_sampler, DistributedSampler)
+        assert not val_sampler.shuffle
+
+    def on_test_start(self, trainer, pl_module):
+        test_sampler = trainer.test_dataloaders[0].sampler
+        assert isinstance(test_sampler, DistributedSampler)
+        assert not test_sampler.shuffle
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_distributed_sampler(tmpdir):
+    """ Test DistributedSampler and it's arguments for DDP backend """
+
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        gpus=[0, 1],
+        num_nodes=1,
+        distributed_backend='ddp_spawn',
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[DistribSamplerCallback()]
+    )
+    trainer.fit(model)
+    trainer.test(ckpt_path=None)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus(tmpdir):
     # we need at least 3 gpus for this test
```
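The commit-message item "define callback outside for pickling" explains why `DistribSamplerCallback` lives at module level: `ddp_spawn` starts its workers with multiprocessing, which pickles what it hands to them (callbacks included), and a class defined inside the test function cannot be pickled by reference. A small demonstration of the difference, under that assumption; `ModuleLevelCallback` and `local_callback_fails` are hypothetical names used only for this sketch:

```python
import pickle

from pytorch_lightning import Callback


class ModuleLevelCallback(Callback):
    """Importable by qualified name, so pickle can serialise instances of it."""


def local_callback_fails():
    class LocalCallback(Callback):  # defined inside a function body
        pass

    try:
        pickle.dumps(LocalCallback())
    except (pickle.PicklingError, AttributeError) as err:
        print(f'cannot pickle a locally defined callback: {err}')


if __name__ == '__main__':
    pickle.dumps(ModuleLevelCallback())  # works: the class is found at module scope
    local_callback_fails()
```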