From 96314cbf463560e9a6ce68322342f69a6a4ca941 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 8 Jul 2019 19:26:51 -0400
Subject: [PATCH] updated dist sampler

---
 .../lightning_module_template.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py
index 72c740ccd7..2b3cd0a8a4 100644
--- a/examples/new_project_templates/lightning_module_template.py
+++ b/examples/new_project_templates/lightning_module_template.py
@@ -159,19 +159,11 @@ class LightningTemplateModel(LightningModule):
         dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)

         # when using multi-node we need to add the datasampler
-        print('-'*100)
-        print(self.trainer.world_size, self.trainer.proc_rank)
-        print('-'*100)
-        train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # try:
-        #     if self.hparams.nb_gpu_nodes > 1:
-        #         print('-'*100)
-        #         print(self.trainer.world_size, self.trainer.proc_rank)
-        #         print('-'*100)
-        #         train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # except Exception as e:
-        #     print('no sampler')
-        #     train_sampler = None
+        try:
+            if self.hparams.nb_gpu_nodes > 1:
+                train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
+        except Exception as e:
+            train_sampler = None

         should_shuffle = train_sampler is None
         loader = DataLoader(
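
Note (not part of the patch): below is a minimal sketch of how the resulting dataloader method could look once the conditional sampler logic is applied. The `__dataloader(self, train)` signature, the transform, and `hparams.batch_size` are assumptions for illustration only; the sketch also initializes `train_sampler = None` up front, which avoids a possible NameError when `nb_gpu_nodes` is 1 and the try/except path never assigns a sampler.

    # Sketch only; names outside the patch (batch_size, __dataloader) are assumed.
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import transforms
    from torchvision.datasets import MNIST

    def __dataloader(self, train):
        transform = transforms.Compose([transforms.ToTensor()])
        dataset = MNIST(root=self.hparams.data_root, train=train,
                        transform=transform, download=True)

        # when using multi-node training, shard the dataset across processes
        train_sampler = None
        if self.hparams.nb_gpu_nodes > 1:
            train_sampler = DistributedSampler(dataset,
                                               num_replicas=self.trainer.world_size,
                                               rank=self.trainer.proc_rank)

        # DistributedSampler handles per-epoch shuffling itself, so only
        # ask the DataLoader to shuffle when no sampler is attached
        should_shuffle = train_sampler is None
        loader = DataLoader(dataset,
                            batch_size=self.hparams.batch_size,
                            shuffle=should_shuffle,
                            sampler=train_sampler)
        return loader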