diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py
index 72c740ccd7..2b3cd0a8a4 100644
--- a/examples/new_project_templates/lightning_module_template.py
+++ b/examples/new_project_templates/lightning_module_template.py
@@ -159,19 +159,13 @@ class LightningTemplateModel(LightningModule):
         dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)
 
         # when using multi-node we need to add the datasampler
-        print('-'*100)
-        print(self.trainer.world_size, self.trainer.proc_rank)
-        print('-'*100)
-        train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # try:
-        #     if self.hparams.nb_gpu_nodes > 1:
-        #         print('-'*100)
-        #         print(self.trainer.world_size, self.trainer.proc_rank)
-        #         print('-'*100)
-        #         train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # except Exception as e:
-        #     print('no sampler')
-        #     train_sampler = None
+        # initialise to None so single-node runs fall through to a shuffled loader
+        train_sampler = None
+        try:
+            if self.hparams.nb_gpu_nodes > 1:
+                train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
+        except Exception:
+            train_sampler = None
 
         should_shuffle = train_sampler is None
         loader = DataLoader(
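
For context, here is a minimal standalone sketch of the pattern the `+` lines implement: attach a `DistributedSampler` only on multi-node runs, and fall back to a plain shuffled loader otherwise. The helper name `build_train_dataloader` and its arguments (`data_root`, `nb_gpu_nodes`, `world_size`, `proc_rank`, `batch_size`) are hypothetical stand-ins for the template's `self.hparams` / `self.trainer` attributes, and the transform is simplified to `ToTensor()`; it is not the template's exact method.

```python
# Minimal sketch of the sampler-guarding pattern from this patch.
# build_train_dataloader and its arguments are hypothetical stand-ins
# for the template's self.hparams / self.trainer attributes.
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.datasets import MNIST


def build_train_dataloader(data_root, nb_gpu_nodes, world_size, proc_rank, batch_size=32):
    dataset = MNIST(root=data_root, train=True,
                    transform=transforms.ToTensor(), download=True)

    # Shard the dataset across processes only when running multi-node;
    # otherwise leave the sampler unset so the loader can shuffle normally.
    train_sampler = None
    if nb_gpu_nodes > 1:
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=proc_rank)

    # DataLoader raises an error if shuffle=True is combined with an explicit
    # sampler, which is why the template computes should_shuffle from the sampler.
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=(train_sampler is None),
                      sampler=train_sampler)
```

Pre-initialising `train_sampler = None` before the guard is what closes the NameError the bare try/except would otherwise leave open on single-node runs, where no branch assigns the variable.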