moved sampler

This commit is contained in:
William Falcon 2019-07-08 18:28:30 -04:00
parent bd2d1ddc07
commit 493a98d591
1 changed files with 3 additions and 1 deletions

View File

@ -161,8 +161,10 @@ class LightningTemplateModel(LightningModule):
# when using multi-node we need to add the datasampler
try:
if self.hparams.nb_gpu_nodes > 1:
print('-'*100)
print(self.trainer.world_size, self.trainer.proc_rank)
print('-'*100)
train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
print('using sampler')
except Exception as e:
print('no sampler')
train_sampler = None