From 2b16c75499aa697d90577b60dced8abcd5084512 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 8 Jul 2019 19:56:52 -0400
Subject: [PATCH] scaled batch size

---
 .../lightning_module_template.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py
index 2eb73ba9a7..994c162121 100644
--- a/examples/new_project_templates/lightning_module_template.py
+++ b/examples/new_project_templates/lightning_module_template.py
@@ -162,16 +162,12 @@ class LightningTemplateModel(LightningModule):
         train_sampler = None
         batch_size = self.hparams.batch_size
 
-        # try:
-        # if self.on_gpu:
-        import torch.distributed as dist
-        print(dist.get_world_size())
-        train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
-        batch_size = batch_size // self.trainer.world_size  # scale batch size
-        print(batch_size)
-
-        # except Exception as e:
-        # pass
+        try:
+            if self.on_gpu:
+                train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
+                batch_size = batch_size // self.trainer.world_size  # scale batch size
+        except Exception as e:
+            pass
+
         should_shuffle = train_sampler is None
         loader = DataLoader(
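
Note on the pattern this patch restores: under distributed training, each process reads a distinct shard of the dataset via DistributedSampler and uses a per-process batch size of (configured batch size // world size), so the effective global batch size stays constant as GPUs are added. Below is a minimal standalone sketch of the same idea, not the template code itself: it reads rank and world size from torch.distributed directly, whereas the template goes through self.trainer.proc_rank and self.trainer.world_size, and build_train_loader is a hypothetical helper name.

    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def build_train_loader(dataset, global_batch_size):
        # Single-process default: no sampler, full batch size.
        sampler = None
        batch_size = global_batch_size
        if dist.is_available() and dist.is_initialized():
            # Each process gets a distinct shard of the dataset ...
            sampler = DistributedSampler(dataset, rank=dist.get_rank())
            # ... and a proportionally smaller per-process batch, keeping
            # the effective global batch size constant across world sizes.
            batch_size = global_batch_size // dist.get_world_size()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            # DistributedSampler shuffles within each shard per epoch, so
            # DataLoader-level shuffling applies only when no sampler is set.
            shuffle=(sampler is None),
        )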