ref: moved accelerator router (#3309)
* ref: moved accelerator
This commit is contained in:
parent b66ce88f0d
commit 0a119403d6
pytorch_lightning/accelerators/accelerator_connector.py (new file)
@@ -0,0 +1,51 @@
+from pytorch_lightning import accelerators
+import os
+
+
+class AcceleratorConnector:
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def select_accelerator(self):
+        # SLURM ddp
+        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
+
+        # torchelastic or general non_slurm ddp
+        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
+        use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed
+
+        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']
+
+        # choose the appropriate accelerator backend
+        if self.trainer.use_ddp2:
+            accelerator_backend = accelerators.DDP2Backend(self.trainer)
+
+        elif use_slurm_ddp:
+            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='slurm_ddp')
+
+        elif use_torchelastic_ddp:
+            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='torchelastic_ddp')
+
+        elif use_ddp_spawn:
+            accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)
+
+        elif self.trainer.distributed_backend == 'ddp':
+            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')
+
+        elif self.trainer.use_dp:
+            accelerator_backend = accelerators.DataParallelBackend(self.trainer)
+
+        elif self.trainer.use_horovod:
+            accelerator_backend = accelerators.HorovodBackend(self.trainer)
+
+        elif self.trainer.use_single_gpu:
+            accelerator_backend = accelerators.GPUBackend(self.trainer)
+
+        elif self.trainer.use_tpu:
+            accelerator_backend = accelerators.TPUBackend(self.trainer)
+
+        else:
+            accelerator_backend = accelerators.CPUBackend(self.trainer)
+
+        return accelerator_backend
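For readers outside the diff: a minimal, self-contained sketch of how the detection flags above resolve, using a throwaway SimpleNamespace in place of a real Trainer. The detection_flags helper and demo_trainer object below are illustrative only, not pytorch_lightning APIs.

import os
from types import SimpleNamespace

def detection_flags(trainer):
    # same checks as AcceleratorConnector.select_accelerator, pulled out for illustration
    use_slurm_ddp = trainer.use_ddp and trainer.is_slurm_managing_tasks
    te_flags_passed = 'WORLD_SIZE' in os.environ and (
        'GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
    use_torchelastic_ddp = trainer.use_ddp and te_flags_passed
    use_ddp_spawn = trainer.use_ddp and trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']
    return use_slurm_ddp, use_torchelastic_ddp, use_ddp_spawn

# outside SLURM, a ddp_spawn run trips only the spawn flag; the torchelastic flag
# depends on whether WORLD_SIZE / GROUP_RANK / NODE_RANK are set in the environment
demo_trainer = SimpleNamespace(use_ddp=True, is_slurm_managing_tasks=False,
                               distributed_backend='ddp_spawn')
print(detection_flags(demo_trainer))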
pytorch_lightning/trainer/trainer.py
@@ -22,8 +22,6 @@ import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader

-from pytorch_lightning.accelerators import (
-    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend, HorovodBackend)
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -56,6 +54,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
+from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector

 from pytorch_lightning.utilities.model_utils import is_overridden

 # warnings to ignore in trainer
@@ -610,6 +610,7 @@ class Trainer(
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
         self.data_connector = DataConnector(self)
+        self.accelerator_connector = AcceleratorConnector(self)
         self.accelerator_backend = None

         # loops
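As a side note, the __init__ change follows the same composition pattern as the existing DataConnector: the trainer owns small connector helpers that each keep a back-reference to it. A minimal sketch with made-up stand-in classes (ConnectorStub, MiniTrainer are not Lightning APIs):

class ConnectorStub:
    # stand-in for DataConnector / AcceleratorConnector: holds a back-reference only
    def __init__(self, trainer):
        self.trainer = trainer

class MiniTrainer:
    def __init__(self):
        self.data_connector = ConnectorStub(self)
        self.accelerator_connector = ConnectorStub(self)
        self.accelerator_backend = None  # chosen later, at fit() time

t = MiniTrainer()
assert t.accelerator_connector.trainer is t  # connectors can reach back into the trainer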
@@ -1022,7 +1023,7 @@ class Trainer(
         # -------------------------
         # TRAIN
         # -------------------------
-        self.accelerator_backend = self.select_accelerator()
+        self.accelerator_backend = self.accelerator_connector.select_accelerator()
         self.accelerator_backend.setup(model)
         results = self.accelerator_backend.train()
         self.accelerator_backend.teardown()
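For context, the call order fit() drives is unchanged by this refactor: select a backend, set it up with the model, train, tear down. A minimal sketch with a made-up DummyBackend (not a Lightning class):

class DummyBackend:
    # stand-in for the objects returned by select_accelerator()
    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        print('setup:', model)

    def train(self):
        return {'loss': 0.0}  # placeholder results

    def teardown(self):
        print('teardown')

def run_fit(trainer, model):
    backend = DummyBackend(trainer)   # stands in for accelerator_connector.select_accelerator()
    backend.setup(model)
    results = backend.train()
    backend.teardown()
    return results

print(run_fit(trainer=None, model='my_model'))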
@@ -1056,49 +1057,6 @@ class Trainer(
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)

-    def select_accelerator(self):
-        # SLURM ddp
-        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
-
-        # torchelastic or general non_slurm ddp
-        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
-        use_torchelastic_ddp = self.use_ddp and te_flags_passed
-
-        use_ddp_spawn = self.use_ddp and self.distributed_backend in ['ddp_cpu', 'ddp_spawn']
-
-        # choose the appropriate accelerator backend
-        if self.use_ddp2:
-            accelerator_backend = DDP2Backend(self)
-
-        elif use_slurm_ddp:
-            accelerator_backend = DDPBackend(self, mode='slurm_ddp')
-
-        elif use_torchelastic_ddp:
-            accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
-
-        elif use_ddp_spawn:
-            accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
-
-        elif self.distributed_backend == 'ddp':
-            accelerator_backend = DDPBackend(self, mode='ddp')
-
-        elif self.use_dp:
-            accelerator_backend = DataParallelBackend(self)
-
-        elif self.use_horovod:
-            accelerator_backend = HorovodBackend(self)
-
-        elif self.use_single_gpu:
-            accelerator_backend = GPUBackend(self)
-
-        elif self.use_tpu:
-            accelerator_backend = TPUBackend(self)
-
-        else:
-            accelerator_backend = CPUBackend(self)
-
-        return accelerator_backend
-
     def setup_training(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
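One more illustrative note: the removed Trainer method and its new home in AcceleratorConnector share the same if/elif priority order. The sketch below reproduces that order with string labels and a plain dict of flags; it is a simplification (for example, the real 'ddp' branch checks trainer.distributed_backend == 'ddp' rather than a boolean flag).

def pick_backend(flags):
    # priority order of the selection chain, highest first (simplified illustration)
    order = [
        ('use_ddp2', 'DDP2Backend'),
        ('use_slurm_ddp', "DDPBackend(mode='slurm_ddp')"),
        ('use_torchelastic_ddp', "DDPBackend(mode='torchelastic_ddp')"),
        ('use_ddp_spawn', 'DDPSpawnBackend'),
        ('backend_is_ddp', "DDPBackend(mode='ddp')"),
        ('use_dp', 'DataParallelBackend'),
        ('use_horovod', 'HorovodBackend'),
        ('use_single_gpu', 'GPUBackend'),
        ('use_tpu', 'TPUBackend'),
    ]
    for flag, backend in order:
        if flags.get(flag):
            return backend
    return 'CPUBackend'  # default when nothing else matches

print(pick_backend({'use_slurm_ddp': True, 'use_dp': True}))  # SLURM wins over dp
print(pick_backend({}))                                       # CPUBackend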