diff --git a/examples/new_project_templates/trainer_gpu_cluster_template.py b/examples/new_project_templates/trainer_gpu_cluster_template.py
index 6a75622d36..9b03fa7afc 100644
--- a/examples/new_project_templates/trainer_gpu_cluster_template.py
+++ b/examples/new_project_templates/trainer_gpu_cluster_template.py
@@ -89,18 +89,13 @@ def main(hparams, cluster, results_dict):
         mode=hparams.model_save_monitor_mode
     )
 
-    # gpus are ; separated for inside a node and , within nodes
-    gpu_list = None
-    if hparams.gpus is not None:
-        gpu_list = [int(x) for x in hparams.gpus.split(';')]
-
     # configure trainer
     trainer = Trainer(
         experiment=exp,
         cluster=cluster,
         checkpoint_callback=checkpoint,
         early_stop_callback=early_stop,
-        gpus=gpu_list,
+        gpus=hparams.gpus,
         nb_gpu_nodes=1
     )
 
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index cacc522143..f79571d4ac 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -86,8 +86,17 @@ class Trainer(TrainerIO):
         self.lr_schedulers = []
         self.amp_level = amp_level
         self.print_nan_grads = print_nan_grads
-        self.data_parallel_device_ids = gpus
         self.data_parallel = gpus is not None and len(gpus) > 0
+        self.data_parallel_device_ids = gpus
+
+        # gpus comes in as a string.
+        # if gpus == '-1' then use all available devices
+        # otherwise, split the string using commas
+        if gpus is not None:
+            if gpus == '-1':
+                self.data_parallel_device_ids = list(range(torch.cuda.device_count()))
+            else:
+                self.data_parallel_device_ids = [int(x.strip()) for x in gpus.split(',')]
 
         # process info
         self.proc_rank = 0
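
With this patch the Trainer accepts the raw gpus string directly, e.g. Trainer(gpus=hparams.gpus) with values such as '0,1' or '-1', and does the parsing itself. Below is a minimal standalone sketch of that parsing behaviour; the helper name parse_gpu_ids is hypothetical and used only for illustration, and it assumes a working torch install.

import torch


def parse_gpu_ids(gpus):
    """Sketch of the parsing the patch adds to Trainer.__init__.

    gpus is either None or a string: '-1' selects every visible CUDA
    device, anything else is a comma-separated list of device indices.
    Returns a list of ints, or None when gpus is None.
    """
    if gpus is None:
        return None
    if gpus == '-1':
        # use all available devices
        return list(range(torch.cuda.device_count()))
    return [int(x.strip()) for x in gpus.split(',')]


if __name__ == '__main__':
    print(parse_gpu_ids('0, 1'))  # -> [0, 1]
    print(parse_gpu_ids('-1'))    # -> [0, ..., torch.cuda.device_count() - 1]
    print(parse_gpu_ids(None))    # -> None

Keeping the parsing inside Trainer means template scripts can forward the command-line value unchanged, and the '-1' shortcut behaves the same for every caller.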