diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 84695d2012..dc01f59cd1 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -495,27 +495,28 @@ If you want each process to load the full dataset, ignore this warning. """ warnings.warn(msg) - if self.use_ddp and\ - not all(isinstance(dataloader, DistributedSampler) - for dataloader in self.val_dataloader): - msg = """ -You're val_dataloader(s) are not all DistributedSamplers. -You're using multiple gpus and multiple nodes without using a DistributedSampler -to assign a subset of your data to each process. To silence this warning, pass a -DistributedSampler to your DataLoader. + if self.use_ddp and self.val_dataloader is not None: + for dataloader in self.val_dataloader: + if not isinstance(dataloader, DistributedSampler): + msg = """ + Your val_dataloader(s) are not all DistributedSamplers. + You're using multiple gpus and multiple nodes without using a DistributedSampler + to assign a subset of your data to each process. To silence this warning, pass a + DistributedSampler to your DataLoader. -ie: this: -dataset = myDataset() -dataloader = Dataloader(dataset) + ie: this: + dataset = myDataset() + dataloader = Dataloader(dataset) -becomes: -dataset = myDataset() -dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) -dataloader = Dataloader(dataset, sampler=dist_sampler) + becomes: + dataset = myDataset() + dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + dataloader = Dataloader(dataset, sampler=dist_sampler) -If you want each process to load the full dataset, ignore this warning. -""" - warnings.warn(msg) + If you want each process to load the full dataset, ignore this warning. + """ + warnings.warn(msg) + break # ----------------------------- # MODEL TRAINING