From 35aa67df5626cafd0c5b7fda7118f92ad03091bb Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 25 Jun 2019 18:20:24 -0400
Subject: [PATCH] updated args

---
 docs/source/examples/fully_featured_trainer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/docs/source/examples/fully_featured_trainer.py b/docs/source/examples/fully_featured_trainer.py
index ab346b47fb..68ceb6376b 100644
--- a/docs/source/examples/fully_featured_trainer.py
+++ b/docs/source/examples/fully_featured_trainer.py
@@ -174,6 +174,7 @@ if __name__ == '__main__':
     # RUN TRAINING
     pdb.set_trace()

+    # cluster and CPU
     if hyperparams.on_cluster:
         # run on HPC cluster
         print('RUNNING ON SLURM CLUSTER')
@@ -186,16 +187,16 @@ if __name__ == '__main__':
         print('RUNNING ON CPU')
         main(hyperparams, None, None)

-    elif hyperparams.single_run_gpu:
+    # single or multiple GPUs on same machine
+    gpu_ids = hyperparams.gpus.split(';')
+    if len(gpu_ids) == 1:
         # run on 1 gpu
-        gpu_ids = hyperparams.gpus.split(';')
         print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
         os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
         main(hyperparams, None, None)

     else:
         # multiple GPUs on same machine
-        gpu_ids = hyperparams.gpus.split(';')
         print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
         hyperparams.optimize_parallel_gpu(
             main_local,
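
For readers skimming the diff: a minimal standalone sketch of the branch this patch introduces, assuming a semicolon-separated `gpus` string such as '0' or '0;1;2' as used by the example trainer. The `select_gpus` helper and the prints standing in for `main(hyperparams, None, None)` / `hyperparams.optimize_parallel_gpu(main_local, ...)` are placeholders for illustration, not part of the patch.

import os


def select_gpus(gpus: str) -> None:
    # mirror of the patched logic: the single_run_gpu flag is gone and the
    # number of devices is inferred from the semicolon-separated string
    gpu_ids = gpus.split(';')
    if len(gpu_ids) == 1:
        # run on 1 gpu: pin the process to the requested device
        print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
        # the real script calls main(hyperparams, None, None) here
    else:
        # multiple GPUs on same machine: the real script hands off to
        # hyperparams.optimize_parallel_gpu(main_local, ...)
        print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')


if __name__ == '__main__':
    select_gpus('0;1')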