updated dist sampler

William Falcon 2019-07-08 19:26:51 -04:00
parent d596ff2039
commit 96314cbf46
1 changed file with 5 additions and 13 deletions


@@ -159,19 +159,11 @@ class LightningTemplateModel(LightningModule):
         dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)

         # when using multi-node we need to add the datasampler
-        print('-'*100)
-        print(self.trainer.world_size, self.trainer.proc_rank)
-        print('-'*100)
-        train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # try:
-        #     if self.hparams.nb_gpu_nodes > 1:
-        #         print('-'*100)
-        #         print(self.trainer.world_size, self.trainer.proc_rank)
-        #         print('-'*100)
-        #         train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
-        # except Exception as e:
-        #     print('no sampler')
-        #     train_sampler = None
+        try:
+            if self.hparams.nb_gpu_nodes > 1:
+                train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
+        except Exception as e:
+            train_sampler = None

         should_shuffle = train_sampler is None
         loader = DataLoader(
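
For context, a minimal runnable sketch of the pattern this commit lands: construct a DistributedSampler only when the run spans multiple nodes, and fall back to plain shuffling otherwise. build_train_loader is an illustrative helper, not Lightning API; nb_gpu_nodes, world_size, and rank are explicit stand-ins for self.hparams.nb_gpu_nodes, self.trainer.world_size, and self.trainer.proc_rank.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_train_loader(dataset, nb_gpu_nodes, world_size, rank, batch_size=32):
    # shard the dataset across processes only when training spans multiple nodes
    train_sampler = None
    if nb_gpu_nodes > 1:
        # passing num_replicas/rank explicitly means the sampler can be built
        # without querying an initialized process group
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)

    # DataLoader forbids shuffle=True together with a sampler, hence the fallback
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
    )


# single-node usage: no sampler, plain shuffling
dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 10, (100,)))
loader = build_train_loader(dataset, nb_gpu_nodes=1, world_size=1, rank=0)

Initializing train_sampler to None up front lets single-node runs fall through cleanly and keeps the shuffle/sampler mutual exclusion in DataLoader explicit.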