scaled batch size
This commit is contained in:
parent
a87073bffd
commit
0bd9152e0a
|
@ -168,6 +168,7 @@ class LightningTemplateModel(LightningModule):
|
|||
print(dist.get_world_size())
|
||||
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
|
||||
batch_size = batch_size // self.trainer.world_size # scale batch size
|
||||
print(batch_size)
|
||||
|
||||
# except Exception as e:
|
||||
# pass
|
||||
|
|
Loading…
Reference in New Issue