diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 7715d5e270..1612a2c5c3 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -79,11 +79,11 @@ class TrainerDataLoadingMixin(ABC): if not isinstance(dataloader, DataLoader): return dataloader - # don't add sampler when user gives one - if dataloader.sampler is not None: - return dataloader + need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu + no_sampler_added = dataloader.sampler is None + + if need_dist_sampler and no_sampler_added: - if self.use_ddp or self.use_ddp2 or self.use_tpu: dl_args = { 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size,