scaled batch size

William Falcon 2019-07-08 19:56:52 -04:00
parent 0bd9152e0a
commit 2b16c75499
1 changed file with 6 additions and 10 deletions


@@ -162,16 +162,12 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size
-        # try:
-        #     if self.on_gpu:
-        import torch.distributed as dist
-        print(dist.get_world_size())
-        train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
-        batch_size = batch_size // self.trainer.world_size  # scale batch size
-        print(batch_size)
-        # except Exception as e:
-        #     pass
+        try:
+            if self.on_gpu:
+                train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
+                batch_size = batch_size // self.trainer.world_size  # scale batch size
+        except Exception as e:
+            pass
         should_shuffle = train_sampler is None
         loader = DataLoader(
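
For context, here is a minimal self-contained sketch of the logic the new hunk produces. The function name build_train_dataloader and the SimpleNamespace stubs standing in for hparams and the trainer are illustrative, not from the repo; the attribute names (hparams.batch_size, trainer.proc_rank, trainer.world_size) mirror those used in the diff.

import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_train_dataloader(dataset, hparams, trainer, on_gpu):
    # Free-function version of the post-commit dataloader logic.
    train_sampler = None
    batch_size = hparams.batch_size
    try:
        if on_gpu:
            # Under DDP each process sees 1/world_size of the data, so the
            # per-process batch size is divided down to keep the effective
            # global batch size equal to hparams.batch_size.
            train_sampler = DistributedSampler(dataset, rank=trainer.proc_rank)
            batch_size = batch_size // trainer.world_size  # scale batch size
    except Exception:
        # DistributedSampler raises when no process group is initialized;
        # in that case fall back to single-process shuffling.
        pass

    # DataLoader forbids shuffle=True together with a sampler, so only
    # shuffle when no DistributedSampler was created.
    should_shuffle = train_sampler is None
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=should_shuffle,
        sampler=train_sampler,
    )


# Single-process usage with a toy dataset: no process group is initialized,
# so the except branch is taken and plain shuffling stays on.
dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))
hparams = SimpleNamespace(batch_size=32)
trainer = SimpleNamespace(proc_rank=0, world_size=1)
loader = build_train_dataloader(dataset, hparams, trainer, on_gpu=False)

Swallowing the exception keeps the same template working on a single GPU, on CPU, and under distributed training without branching on the backend, which is the point of the cleanup in this commit.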