scaled batch size
This commit is contained in:
parent
a87073bffd
commit
0bd9152e0a
|
@ -168,6 +168,7 @@ class LightningTemplateModel(LightningModule):
|
||||||
print(dist.get_world_size())
|
print(dist.get_world_size())
|
||||||
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
|
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
|
||||||
batch_size = batch_size // self.trainer.world_size # scale batch size
|
batch_size = batch_size // self.trainer.world_size # scale batch size
|
||||||
|
print(batch_size)
|
||||||
|
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# pass
|
# pass
|
||||||
|
|
Loading…
Reference in New Issue