diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ad7b9ebe59..3b78306af3 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -277,7 +277,7 @@ class Trainer(TrainerIO): self.test_dataloader = model.test_dataloader self.val_dataloader = model.val_dataloader - if self.data_parallel and not issubclass(self.tng_dataloader.sampler, DistributedSampler): + if self.data_parallel and not isinstance(self.tng_dataloader.sampler, DistributedSampler): msg = ''' when using multiple gpus and multiple nodes you must pass a DistributedSampler to DataLoader(sampler).