diff --git a/examples/new_project_templates/trainer_gpu_cluster_template.py b/examples/new_project_templates/trainer_gpu_cluster_template.py
index 9b03fa7afc..4a357662f9 100644
--- a/examples/new_project_templates/trainer_gpu_cluster_template.py
+++ b/examples/new_project_templates/trainer_gpu_cluster_template.py
@@ -171,13 +171,18 @@ if __name__ == '__main__':
     # ---------------------
     # RUN TRAINING
     # ---------------------
+    # -1 means use all gpus
+    # otherwise use the visible ones
+    if hyperparams.gpus is not None:
+        if hyperparams.gpus == '-1':
+            gpu_ids = list(range(0, torch.cuda.device_count()))
+        else:
+            gpu_ids = hyperparams.gpus.split(',')
 
     # cluster and CPU
     if hyperparams.on_cluster:
         # run on HPC cluster
         print('RUNNING ON SLURM CLUSTER')
-        gpu_ids = hyperparams.gpus.split(';')
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
         optimize_on_cluster(hyperparams)
 
     elif hyperparams.gpus is None:
@@ -185,12 +190,8 @@ if __name__ == '__main__':
         print('RUNNING ON CPU')
         main(hyperparams, None, None)
 
-    # single or multiple GPUs on same machine
-    gpu_ids = hyperparams.gpus.split(';')
     if hyperparams.interactive:
-        # run on 1 gpu
         print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {gpu_ids}')
-        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
         main(hyperparams, None, None)
 
     else:
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 77a8c91cdf..77beb7161f 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -99,6 +99,10 @@ class Trainer(TrainerIO):
 
         self.data_parallel = self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) > 0
 
+        # set the correct cuda visible devices (using pci order)
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(self.data_parallel_device_ids)
+
         # process info
         self.proc_rank = 0
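
For context on the trainer.py hunk: CUDA_DEVICE_ORDER="PCI_BUS_ID" makes CUDA enumerate GPUs in the same order nvidia-smi reports them, and CUDA_VISIBLE_DEVICES restricts which physical devices the process can see. Below is a minimal standalone sketch of that mechanism, not code from this diff; the gpus string is a hypothetical example of what a user might pass via a --gpus flag, and the ids are joined as strings before export.

import os

# Hypothetical comma-separated device list, e.g. taken from a --gpus flag.
gpus = '0,2'

# Both variables are read when the CUDA runtime is first initialised in this
# process, so they are exported before anything touches the GPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # enumerate in PCI bus order, matching nvidia-smi
os.environ["CUDA_VISIBLE_DEVICES"] = gpus        # expose only physical devices 0 and 2

import torch  # imported after the environment is configured

if torch.cuda.is_available():
    # The exposed devices are renumbered from zero inside the process,
    # so physical GPU 2 shows up here as cuda:1.
    print(torch.cuda.device_count())  # prints 2 on a machine with at least 3 GPUs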