diff --git a/docs/source/examples/fully_featured_trainer.py b/docs/source/examples/fully_featured_trainer.py
index ab346b47fb..68ceb6376b 100644
--- a/docs/source/examples/fully_featured_trainer.py
+++ b/docs/source/examples/fully_featured_trainer.py
@@ -174,6 +174,7 @@ if __name__ == '__main__':
 
     # RUN TRAINING
     pdb.set_trace()
+    # cluster and CPU
     if hyperparams.on_cluster:
         # run on HPC cluster
         print('RUNNING ON SLURM CLUSTER')
@@ -186,16 +187,16 @@ if __name__ == '__main__':
         print('RUNNING ON CPU')
         main(hyperparams, None, None)
 
-    elif hyperparams.single_run_gpu:
+    # single or multiple GPUs on same machine
+    gpu_ids = hyperparams.gpus.split(';')
+    if len(gpu_ids) == 1:
         # run on 1 gpu
-        gpu_ids = hyperparams.gpus.split(';')
         print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
         os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
         main(hyperparams, None, None)
 
     else:
         # multiple GPUs on same machine
-        gpu_ids = hyperparams.gpus.split(';')
         print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
         hyperparams.optimize_parallel_gpu(
             main_local,
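
Note for readers of the patch: the change above replaces the 'single_run_gpu' flag with a single ';'-separated 'gpus' string that decides between the one-GPU and multi-GPU branches. Below is a minimal, standalone sketch of that selection logic; it is not part of the diff, and the helper name select_gpus and the example values are hypothetical.

    import os

    def select_gpus(gpus):
        """Split a ';'-separated GPU string, e.g. '0' or '0;1;2'."""
        gpu_ids = gpus.split(';')
        if len(gpu_ids) == 1:
            # single-GPU branch: pin CUDA to that one device
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ids[0]
        return gpu_ids

    print(select_gpus('0'))      # ['0']            -> single-GPU path
    print(select_gpus('0;1;2'))  # ['0', '1', '2']  -> multi-GPU path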