updated args

This commit is contained in:
William Falcon 2019-06-25 18:20:24 -04:00
parent e58cfafa74
commit 35aa67df56
1 changed files with 4 additions and 3 deletions

View File

@ -174,6 +174,7 @@ if __name__ == '__main__':
# RUN TRAINING # RUN TRAINING
pdb.set_trace() pdb.set_trace()
# cluster and CPU
if hyperparams.on_cluster: if hyperparams.on_cluster:
# run on HPC cluster # run on HPC cluster
print('RUNNING ON SLURM CLUSTER') print('RUNNING ON SLURM CLUSTER')
@ -186,16 +187,16 @@ if __name__ == '__main__':
print('RUNNING ON CPU') print('RUNNING ON CPU')
main(hyperparams, None, None) main(hyperparams, None, None)
elif hyperparams.single_run_gpu: # single or multiple GPUs on same machine
gpu_ids = hyperparams.gpus.split(';')
if len(gpu_ids) == 1:
# run on 1 gpu # run on 1 gpu
gpu_ids = hyperparams.gpus.split(';')
print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}') print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0] os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
main(hyperparams, None, None) main(hyperparams, None, None)
else: else:
# multiple GPUs on same machine # multiple GPUs on same machine
gpu_ids = hyperparams.gpus.split(';')
print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}') print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
hyperparams.optimize_parallel_gpu( hyperparams.optimize_parallel_gpu(
main_local, main_local,