diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index 994c162121..359c85a8c5 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -164,6 +164,7 @@ class LightningTemplateModel(LightningModule): try: if self.on_gpu: + print('distributing') train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) batch_size = batch_size // self.trainer.world_size # scale batch size except Exception as e: