diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py index 994c162121..5245e2cd12 100644 --- a/examples/new_project_templates/lightning_module_template.py +++ b/examples/new_project_templates/lightning_module_template.py @@ -232,11 +232,11 @@ class LightningTemplateModel(LightningModule): parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) # training params (opt) - parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005], + parser.opt_list('--learning_rate', default=0.001*8, type=float, options=[0.0001, 0.0005, 0.001, 0.005], tunable=False) parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False) # if using 2 nodes with 4 gpus each the batch size here (256) will be 256 / (2*8) = 16 per gpu - parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False, + parser.opt_list('--batch_size', default=256*8, type=int, options=[32, 64, 128, 256], tunable=False, help='batch size will be divided over all the gpus being used across all nodes') return parser