moved sampler
This commit is contained in:
parent
bd2d1ddc07
commit
493a98d591
|
@ -161,8 +161,10 @@ class LightningTemplateModel(LightningModule):
|
|||
# when using multi-node we need to add the datasampler
|
||||
try:
|
||||
if self.hparams.nb_gpu_nodes > 1:
|
||||
print('-'*100)
|
||||
print(self.trainer.world_size, self.trainer.proc_rank)
|
||||
print('-'*100)
|
||||
train_sampler = DistributedSampler(dataset, num_replicas=self.trainer.world_size, rank=self.trainer.proc_rank)
|
||||
print('using sampler')
|
||||
except Exception as e:
|
||||
print('no sampler')
|
||||
train_sampler = None
|
||||
|
|
Loading…
Reference in New Issue