scaled batch size

parent 0bd9152e0a
commit 2b16c75499
@@ -162,16 +162,12 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size

-        # try:
-        # if self.on_gpu:
-        import torch.distributed as dist
-        print(dist.get_world_size())
         try:
             if self.on_gpu:
                 train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
+                batch_size = batch_size // self.trainer.world_size  # scale batch size
+                print(batch_size)

-        # except Exception as e:
-        # pass
         except Exception as e:
             pass

         should_shuffle = train_sampler is None
         loader = DataLoader(
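For context, this change applies the common distributed-training convention of dividing the configured batch size by the world size, so the effective global batch stays constant no matter how many processes are running. Below is a minimal self-contained sketch of the same pattern outside Lightning; the names `build_loader`, `dataset`, and `global_batch_size` are illustrative, not identifiers from this repo.

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_loader(dataset, global_batch_size):
    sampler = None
    batch_size = global_batch_size
    if dist.is_available() and dist.is_initialized():
        # Each process sees 1/world_size of the data, so shrink the
        # per-process batch to keep the effective global batch constant.
        sampler = DistributedSampler(dataset, rank=dist.get_rank())
        batch_size = batch_size // dist.get_world_size()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # the sampler already shuffles per epoch
        sampler=sampler,
    )

For example, with a global batch size of 256 across 4 processes, each process loads 256 // 4 = 64 samples per step, and DistributedSampler assigns each rank a disjoint quarter of the dataset (call sampler.set_epoch(epoch) at the start of each epoch to reshuffle).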