scaled batch size

William Falcon 2019-07-08 19:56:52 -04:00
parent 0bd9152e0a
commit 2b16c75499
1 changed file with 6 additions and 10 deletions


@@ -162,16 +162,12 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size
-        # try:
-        # if self.on_gpu:
-        import torch.distributed as dist
-        print(dist.get_world_size())
-        train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
-        batch_size = batch_size // self.trainer.world_size  # scale batch size
-        print(batch_size)
-        # except Exception as e:
-        #     pass
+        try:
+            if self.on_gpu:
+                train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
+                batch_size = batch_size // self.trainer.world_size  # scale batch size
+        except Exception as e:
+            pass
         should_shuffle = train_sampler is None
         loader = DataLoader(
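
The hunk above restores the pattern of dividing the global batch size by the world size so the effective batch across all GPUs stays constant, with the try/except letting single-process runs fall through. A minimal standalone sketch of that pattern follows; the dataset and global batch size are illustrative placeholders, and in the multi-GPU case a torch.distributed process group is assumed to already be initialized by the launcher:

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.randn(1024, 10))  # placeholder dataset
    global_batch_size = 256                         # placeholder hyperparameter

    train_sampler = None
    batch_size = global_batch_size
    try:
        # Succeeds only when a distributed process group is initialized;
        # otherwise get_world_size() raises and we keep the defaults.
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        batch_size = batch_size // world_size  # each process loads one shard
    except Exception:
        pass  # single-process run: no sampler, full batch size

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(train_sampler is None),  # the sampler shuffles when it is set
        sampler=train_sampler,
    )

Note that DataLoader forbids passing shuffle=True together with a sampler, which is why shuffling is enabled only when train_sampler is None, mirroring the should_shuffle logic in the diff.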