import os

from pytorch_lightning import accelerators


class AcceleratorConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def select_accelerator(self):
        """Pick the accelerator backend that matches the trainer's distributed settings."""
        # SLURM ddp
        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

        # torchelastic or general non_slurm ddp
        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
        use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed

        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']

        # choose the appropriate accelerator backend
        if self.trainer.use_ddp2:
            accelerator_backend = accelerators.DDP2Backend(self.trainer)

        elif use_slurm_ddp:
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='slurm_ddp')

        elif use_torchelastic_ddp:
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='torchelastic_ddp')

        elif use_ddp_spawn:
            accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

        elif self.trainer.distributed_backend == 'ddp':
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')

        elif self.trainer.use_dp:
            accelerator_backend = accelerators.DataParallelBackend(self.trainer)

        elif self.trainer.use_horovod:
            accelerator_backend = accelerators.HorovodBackend(self.trainer)

        elif self.trainer.use_single_gpu:
            accelerator_backend = accelerators.GPUBackend(self.trainer)

        elif self.trainer.use_tpu:
            accelerator_backend = accelerators.TPUBackend(self.trainer)

        else:
            accelerator_backend = accelerators.CPUBackend(self.trainer)

        return accelerator_backend
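

if __name__ == "__main__":
    # --- Usage sketch, not part of the original connector code. ---
    # A minimal illustration of how the connector might be driven in isolation.
    # `_DummyTrainer` is a hypothetical stub carrying only the flags read above;
    # a real Trainer derives these from its constructor arguments.
    class _DummyTrainer:
        use_ddp = False
        use_ddp2 = False
        use_dp = False
        use_horovod = False
        use_single_gpu = False
        use_tpu = False
        is_slurm_managing_tasks = False
        distributed_backend = None
        num_processes = 1

    connector = AcceleratorConnector(_DummyTrainer())
    backend = connector.select_accelerator()
    # With every flag unset, selection falls through to the CPU backend.
    print(type(backend).__name__)  # CPUBackend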