scaled batch size
parent 0bd9152e0a
commit 2b16c75499
@@ -162,16 +162,12 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size
 
-        # try:
-        # if self.on_gpu:
-        import torch.distributed as dist
-        print(dist.get_world_size())
-        train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
-        batch_size = batch_size // self.trainer.world_size  # scale batch size
-        print(batch_size)
-
-        # except Exception as e:
-        # pass
+        try:
+            if self.on_gpu:
+                train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
+                batch_size = batch_size // self.trainer.world_size  # scale batch size
+        except Exception as e:
+            pass
 
         should_shuffle = train_sampler is None
         loader = DataLoader(
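For context, a minimal sketch of the dataloader hook as it reads after this commit. The method name train_dataloader and the _build_dataset helper are assumptions for illustration (the hunk only shows the class header); the sampler setup, the batch-size scaling, and the shuffle logic come straight from the diff.

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler


    def train_dataloader(self):  # method name assumed; hunk header only shows the class
        dataset = self._build_dataset()  # hypothetical helper; dataset creation is outside the hunk

        train_sampler = None
        batch_size = self.hparams.batch_size

        try:
            if self.on_gpu:
                # one DistributedSampler per process so each GPU reads a distinct shard
                train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
                # split the configured batch size across processes so the effective
                # global batch size stays what the user asked for
                batch_size = batch_size // self.trainer.world_size  # scale batch size
        except Exception:
            pass  # not running distributed: keep the plain single-process loader

        # DistributedSampler shuffles internally, so only shuffle when no sampler is set
        should_shuffle = train_sampler is None
        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=should_shuffle,
            sampler=train_sampler,
        )
        return loader

One caveat of the floor division: batch_size silently drops to 0 when world_size exceeds the configured batch size, so a guard such as max(1, batch_size // world_size) may be worth considering.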