moved cuda flags inside trainer
commit f5a87c5016
parent 523cc9f2be
@@ -171,13 +171,18 @@ if __name__ == '__main__':
    # ---------------------
    # RUN TRAINING
    # ---------------------
    # -1 means use all gpus
    # otherwise use the visible ones
    if hyperparams.gpus is not None:
        if hyperparams.gpus == '-1':
            gpu_ids = list(range(0, torch.cuda.device_count()))
        else:
            gpu_ids = hyperparams.gpus.split(',')

    # cluster and CPU
    if hyperparams.on_cluster:
        # run on HPC cluster
        print('RUNNING ON SLURM CLUSTER')
        gpu_ids = hyperparams.gpus.split(';')
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
        optimize_on_cluster(hyperparams)

    elif hyperparams.gpus is None:
@@ -185,12 +190,8 @@ if __name__ == '__main__':
        print('RUNNING ON CPU')
        main(hyperparams, None, None)

    # single or multiple GPUs on same machine
    gpu_ids = hyperparams.gpus.split(';')
    if hyperparams.interactive:
        # run on 1 gpu
        print(f'RUNNING INTERACTIVE MODE ON GPUS. gpu ids: {gpu_ids}')
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)
        main(hyperparams, None, None)

    else:
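For context, a minimal self-contained sketch of the GPU-selection logic the two hunks above implement. It is illustrative only and not part of the commit: the helper name resolve_gpu_ids and the on_cluster flag are assumptions, and it normalizes the fact that the script splits on ',' in one branch and ';' in the cluster/interactive branches.

import os
import torch

def resolve_gpu_ids(gpus, on_cluster=False):
    # Illustrative helper (not in the commit): mirrors the parsing above.
    if gpus is None:
        return None  # no GPUs requested -> CPU run
    if gpus == '-1':
        # -1 means use every GPU torch can see
        return [str(i) for i in range(torch.cuda.device_count())]
    # the script splits on ',' locally and ';' on the cluster
    sep = ';' if on_cluster else ','
    return gpus.split(sep)

# usage: restrict this process to GPUs 0 and 2
gpu_ids = resolve_gpu_ids('0,2')
if gpu_ids is not None:
    # must be set before any CUDA context is created for it to take effect
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpu_ids)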
@@ -99,6 +99,10 @@ class Trainer(TrainerIO):

        self.data_parallel = self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) > 0

        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(self.data_parallel_device_ids)

        # process info
        self.proc_rank = 0
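A minimal standalone sketch of the pattern this Trainer change introduces, for illustration only: the class name MiniTrainer and its constructor signature are assumptions rather than the real Trainer API, and unlike the hunk above it guards against a missing device list before touching the environment.

import os

class MiniTrainer:
    # Illustrative sketch only; not the actual Trainer API.
    def __init__(self, data_parallel_device_ids=None):
        # device ids are expected as strings, e.g. ['0', '1']
        self.data_parallel_device_ids = data_parallel_device_ids
        self.data_parallel = (
            self.data_parallel_device_ids is not None
            and len(self.data_parallel_device_ids) > 0
        )
        if self.data_parallel:
            # number devices by PCI bus order, then expose only the requested ones
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(self.data_parallel_device_ids)

# usage: the caller no longer sets CUDA_VISIBLE_DEVICES itself
trainer = MiniTrainer(data_parallel_device_ids=['0', '1'])

Setting the flags inside the trainer keeps GPU selection in one place instead of repeating it in every entry-point script.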