Updated distributed sampler: create DistributedSampler only for multi-node runs
This commit is contained in:
parent
d596ff2039
commit
96314cbf46
|
@ -159,19 +159,11 @@ class LightningTemplateModel(LightningModule):
|
|||
dataset = MNIST(root=self.hparams.data_root, train=train, transform=transform, download=True)
|
||||
|
||||
# when using multi-node we need to add the datasampler
|
||||
print('-'*100)
|
||||
print(self.trainer.world_size, self.trainer.proc_rank)
|
||||
print('-'*100)
|
||||
train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
|
||||
# try:
|
||||
# if self.hparams.nb_gpu_nodes > 1:
|
||||
# print('-'*100)
|
||||
# print(self.trainer.world_size, self.trainer.proc_rank)
|
||||
# print('-'*100)
|
||||
# train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
|
||||
# except Exception as e:
|
||||
# print('no sampler')
|
||||
# train_sampler = None
|
||||
try:
|
||||
if self.hparams.nb_gpu_nodes > 1:
|
||||
train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
|
||||
except Exception as e:
|
||||
train_sampler = None
|
||||
|
||||
should_shuffle = train_sampler is None
|
||||
loader = DataLoader(
|
||||
|
|
Loading…
Reference in New Issue