diff --git a/examples/new_project_templates/lightning_module_template.py b/examples/new_project_templates/lightning_module_template.py
index c0157a774f..b78efa374b 100644
--- a/examples/new_project_templates/lightning_module_template.py
+++ b/examples/new_project_templates/lightning_module_template.py
@@ -210,7 +210,7 @@ class LightningTemplateModel(LightningModule):
 
         # network params
-        parser.add_argument('--nb_gpu_nodes', type=int, default=2)
+        parser.add_argument('--nb_gpu_nodes', type=int, default=1)
         parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
         parser.add_argument('--in_features', default=28*28)
         parser.add_argument('--out_features', default=10)
 
diff --git a/examples/new_project_templates/trainer_gpu_cluster_template.py b/examples/new_project_templates/trainer_gpu_cluster_template.py
index 1d05da3d04..7db414a64c 100644
--- a/examples/new_project_templates/trainer_gpu_cluster_template.py
+++ b/examples/new_project_templates/trainer_gpu_cluster_template.py
@@ -165,7 +165,6 @@ if __name__ == '__main__':
     parent_parser.add_argument('-gpu_partition', type=str)
     parent_parser.add_argument('-per_experiment_nb_gpus', type=int)
 
-    # allow model to overwrite or extend args
     TRAINING_MODEL = AVAILABLE_MODELS[model_name]
     parser = TRAINING_MODEL.add_model_specific_args(parent_parser, root_dir)
 
@@ -174,35 +173,39 @@ if __name__ == '__main__':
     # ---------------------
     # RUN TRAINING
     # ---------------------
-    # -1 means use all gpus
-    # otherwise use the visible ones
-    if hyperparams.gpus is not None:
-        if hyperparams.gpus == '-1':
-            gpu_ids = list(range(0, torch.cuda.device_count()))
-        else:
-            gpu_ids = hyperparams.gpus.split(',')
 
-    # cluster and CPU
+    # RUN ON CLUSTER
     if hyperparams.on_cluster:
         # run on HPC cluster
         print('RUNNING ON SLURM CLUSTER')
         optimize_on_cluster(hyperparams)
 
-    elif hyperparams.gpus is None:
+    # RUN ON GPUS
+    if hyperparams.gpus is not None:
+        # -1 means use all gpus
+        # otherwise use the visible ones
+        if hyperparams.gpus == '-1':
+            gpu_ids = list(range(0, torch.cuda.device_count()))
+        else:
+            gpu_ids = hyperparams.gpus.split(',')
+
+        if hyperparams.interactive:
+            print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {gpu_ids}')
+            main(hyperparams, None, None)
+
+        else:
+            # multiple GPUs on same machine
+            print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
+            hyperparams.optimize_parallel_gpu(
+                main_local,
+                gpu_ids=gpu_ids,
+                nb_trials=hyperparams.nb_hopt_trials,
+                nb_workers=len(gpu_ids)
+            )
+
+    # RUN ON CPU
+    else:
         # run on cpu
         print('RUNNING ON CPU')
        main(hyperparams, None, None)
 
-        if hyperparams.interactive:
-            print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {gpu_ids}')
-            main(hyperparams, None, None)
-
-        else:
-            # multiple GPUs on same machine
-            print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
-            hyperparams.optimize_parallel_gpu(
-                main_local,
-                gpu_ids=gpu_ids,
-                nb_trials=hyperparams.nb_hopt_trials,
-                nb_workers=len(gpu_ids)
-            )
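
For reference, a minimal sketch of the dispatch order this patch establishes: cluster submission is checked first, then the GPU branch (which now owns the `--gpus` parsing it previously shared with the CPU path), and CPU is the fallback. The `resolve_gpu_ids` helper, the `dispatch` wrapper, and the literal argument values below are hypothetical, written only to illustrate the flow:

    import torch

    def resolve_gpu_ids(gpus):
        # hypothetical helper mirroring the moved parsing:
        # '-1' selects every visible CUDA device, otherwise the
        # flag is treated as a comma-separated list of device ids
        if gpus == '-1':
            return list(range(torch.cuda.device_count()))
        return gpus.split(',')

    def dispatch(on_cluster, gpus, interactive):
        # cluster submission runs first and independently,
        # exactly as in the patched template (an `if`, not `elif`)
        if on_cluster:
            print('RUNNING ON SLURM CLUSTER')
        if gpus is not None:
            gpu_ids = resolve_gpu_ids(gpus)
            mode = 'INTERACTIVE' if interactive else 'MULTI GPU'
            print(f'{mode} run on gpu ids: {gpu_ids}')
        else:
            print('RUNNING ON CPU')

    dispatch(on_cluster=False, gpus='0,1', interactive=False)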