From e48422df3821f0202d07950721fb8d90f39a7f38 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 1 Apr 2020 12:57:36 -0400 Subject: [PATCH] Sampler (#1328) * sampler * check for dataloader type * check for dataloader type * fixed sampler cases --- pytorch_lightning/trainer/data_loading.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 7715d5e270..1612a2c5c3 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -79,11 +79,11 @@ class TrainerDataLoadingMixin(ABC): if not isinstance(dataloader, DataLoader): return dataloader - # don't add sampler when user gives one - if dataloader.sampler is not None: - return dataloader + need_dist_sampler = self.use_ddp or self.use_ddp2 or self.use_tpu + no_sampler_added = dataloader.sampler is None + + if need_dist_sampler and no_sampler_added: - if self.use_ddp or self.use_ddp2 or self.use_tpu: dl_args = { 'dataset': dataloader.dataset, 'batch_size': dataloader.batch_size,