updated args

This commit is contained in:
William Falcon 2019-06-25 18:18:20 -04:00
parent b9d5397196
commit e58eee8d6a
1 changed files with 9 additions and 8 deletions

View File

@ -169,29 +169,30 @@ if __name__ == '__main__':
# format GPU layout # format GPU layout
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
gpu_ids = hyperparams.gpus.split(';')
# RUN TRAINING # RUN TRAINING
if hyperparams.on_cluster: if hyperparams.on_cluster:
# Gets called when running via HPC cluster # run on HPC cluster
print('RUNNING ON SLURM CLUSTER') print('RUNNING ON SLURM CLUSTER')
gpu_ids = hyperparams.gpus.split(';')
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids) os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
optimize_on_cluster(hyperparams) optimize_on_cluster(hyperparams)
elif hyperparams.gpus is None:
# run on cpu
print('RUNNING ON CPU')
main(hyperparams, None, None)
elif hyperparams.single_run_gpu: elif hyperparams.single_run_gpu:
# run on 1 gpu # run on 1 gpu
gpu_ids = hyperparams.gpus.split(';')
print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}') print(f'RUNNING 1 TRIAL ON GPU. gpu: {gpu_ids[0]}')
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0] os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids[0]
main(hyperparams, None, None) main(hyperparams, None, None)
elif hyperparams.local or hyperparams.single_run:
# run 1 trial but on CPU
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
print('RUNNING LOCALLY')
main(hyperparams, None, None)
else: else:
# multiple GPUs on same machine # multiple GPUs on same machine
gpu_ids = hyperparams.gpus.split(';')
print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}') print(f'RUNNING MULTI GPU. GPU ids: {gpu_ids}')
hyperparams.optimize_parallel_gpu( hyperparams.optimize_parallel_gpu(
main_local, main_local,