ref: moved accelerator router (#3309)

* ref: moved accelerator selection (select_accelerator) out of Trainer and into a new AcceleratorConnector class
William Falcon 2020-09-01 15:48:28 -04:00 committed by GitHub
parent b66ce88f0d
commit 0a119403d6
2 changed files with 55 additions and 46 deletions

pytorch_lightning/accelerators/accelerator_connector.py (new file)

@@ -0,0 +1,51 @@
from pytorch_lightning import accelerators
import os


class AcceleratorConnector:

    def __init__(self, trainer):
        self.trainer = trainer

    def select_accelerator(self):
        # SLURM ddp
        use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

        # torchelastic or general non_slurm ddp
        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
        use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed

        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']

        # choose the appropriate accelerator backend
        if self.trainer.use_ddp2:
            accelerator_backend = accelerators.DDP2Backend(self.trainer)

        elif use_slurm_ddp:
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='slurm_ddp')

        elif use_torchelastic_ddp:
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='torchelastic_ddp')

        elif use_ddp_spawn:
            accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

        elif self.trainer.distributed_backend == 'ddp':
            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')

        elif self.trainer.use_dp:
            accelerator_backend = accelerators.DataParallelBackend(self.trainer)

        elif self.trainer.use_horovod:
            accelerator_backend = accelerators.HorovodBackend(self.trainer)

        elif self.trainer.use_single_gpu:
            accelerator_backend = accelerators.GPUBackend(self.trainer)

        elif self.trainer.use_tpu:
            accelerator_backend = accelerators.TPUBackend(self.trainer)

        else:
            accelerator_backend = accelerators.CPUBackend(self.trainer)

        return accelerator_backend
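
Editor's note: the routing above resolves in a fixed precedence order: DDP2, SLURM-managed DDP, torchelastic DDP, spawn-based DDP (ddp_cpu / ddp_spawn), plain ddp, DP, Horovod, single GPU, TPU, and finally CPU. The snippet below is a minimal standalone sketch of that decision order only; it uses a SimpleNamespace stub and string labels instead of the real Trainer and backend classes, and the names route and stub are illustrative, not part of the library.

import os
from types import SimpleNamespace

def route(trainer):
    # mirrors the branch order of AcceleratorConnector.select_accelerator,
    # returning labels instead of backend objects
    use_slurm_ddp = trainer.use_ddp and trainer.is_slurm_managing_tasks
    # torchelastic is detected purely from environment variables
    te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
    use_torchelastic_ddp = trainer.use_ddp and te_flags_passed
    use_ddp_spawn = trainer.use_ddp and trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']

    if trainer.use_ddp2:
        return 'ddp2'
    if use_slurm_ddp:
        return 'ddp (slurm)'
    if use_torchelastic_ddp:
        return 'ddp (torchelastic)'
    if use_ddp_spawn:
        return 'ddp_spawn'
    if trainer.distributed_backend == 'ddp':
        return 'ddp'
    if trainer.use_dp:
        return 'dp'
    if trainer.use_horovod:
        return 'horovod'
    if trainer.use_single_gpu:
        return 'single gpu'
    if trainer.use_tpu:
        return 'tpu'
    return 'cpu'

# a non-SLURM, non-torchelastic run with distributed_backend='ddp_spawn'
# lands on the spawn backend
stub = SimpleNamespace(
    use_ddp=True, use_ddp2=False, use_dp=False, use_horovod=False,
    use_single_gpu=False, use_tpu=False,
    is_slurm_managing_tasks=False, distributed_backend='ddp_spawn',
)
print(route(stub))  # -> 'ddp_spawn'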

pytorch_lightning/trainer/trainer.py

@@ -22,8 +22,6 @@ import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader
-from pytorch_lightning.accelerators import (
-    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend, HorovodBackend)
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
@@ -56,6 +54,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
from pytorch_lightning.trainer.data_connector import DataConnector
+from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities.model_utils import is_overridden
# warnings to ignore in trainer
@@ -610,6 +610,7 @@ class Trainer(
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
+        self.accelerator_connector = AcceleratorConnector(self)
        self.accelerator_backend = None

        # loops
@@ -1022,7 +1023,7 @@
        # -------------------------
        # TRAIN
        # -------------------------
-        self.accelerator_backend = self.select_accelerator()
+        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()
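
Editor's note: the hunk above also shows the lifecycle every accelerator backend is driven through: it is selected, setup(model) is called, train() runs and returns the results, and teardown() cleans up. The sketch below is a hypothetical stand-in backend (not one of the real classes in pytorch_lightning.accelerators), written only to illustrate that call sequence.

class NoOpBackend:
    """Illustrative stand-in following the select -> setup -> train -> teardown sequence."""

    def __init__(self, trainer):
        self.trainer = trainer
        self.model = None

    def setup(self, model):
        # the real backends move the model to the right device(s) and wire up distributed state here
        self.model = model

    def train(self):
        # the real backends run the training loop and return its results
        return {'status': 'ok'}

    def teardown(self):
        # release whatever was acquired in setup()
        self.model = None


# the Trainer drives whichever backend was selected through the same four steps
backend = NoOpBackend(trainer=None)
backend.setup(model=None)
results = backend.train()
backend.teardown()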
@@ -1056,49 +1057,6 @@
        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

-    def select_accelerator(self):
-        # SLURM ddp
-        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
-        # torchelastic or general non_slurm ddp
-        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
-        use_torchelastic_ddp = self.use_ddp and te_flags_passed
-        use_ddp_spawn = self.use_ddp and self.distributed_backend in ['ddp_cpu', 'ddp_spawn']
-        # choose the appropriate accelerator backend
-        if self.use_ddp2:
-            accelerator_backend = DDP2Backend(self)
-        elif use_slurm_ddp:
-            accelerator_backend = DDPBackend(self, mode='slurm_ddp')
-        elif use_torchelastic_ddp:
-            accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
-        elif use_ddp_spawn:
-            accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
-        elif self.distributed_backend == 'ddp':
-            accelerator_backend = DDPBackend(self, mode='ddp')
-        elif self.use_dp:
-            accelerator_backend = DataParallelBackend(self)
-        elif self.use_horovod:
-            accelerator_backend = HorovodBackend(self)
-        elif self.use_single_gpu:
-            accelerator_backend = GPUBackend(self)
-        elif self.use_tpu:
-            accelerator_backend = TPUBackend(self)
-        else:
-            accelerator_backend = CPUBackend(self)
-        return accelerator_backend

    def setup_training(self, model: LightningModule):
        """Sanity check a few things before starting actual training.