refactor 3/n (#2709)
* refactor into gpu accelerator
This commit is contained in:
parent: b34217e410
commit: 4dbd761a1c
@@ -7,7 +7,7 @@
     "pytorch_lightning/__init__.py",
     "pytorch_lightning/callbacks",
     "pytorch_lightning/core",
-    "pytorch_lightning/accelerators",
+    "pytorch_lightning/accelerator_backends",
     "pytorch_lightning/loggers",
     "pytorch_lightning/logging",
     "pytorch_lightning/metrics",
@@ -138,7 +138,7 @@ language = None
 exclude_patterns = [
     'api/pytorch_lightning.rst',
     'api/pl_examples.*',
-    'api/pytorch_lightning.accelerators.*',
+    'api/pytorch_lightning.accelerator_backends.*',
     'api/modules.rst',
     'PULL_REQUEST_TEMPLATE.md',
@@ -0,0 +1,3 @@
+from pytorch_lightning.accelerator_backends.gpu_backend import GPUBackend
+from pytorch_lightning.accelerator_backends.tpu_backend import TPUBackend
+from pytorch_lightning.accelerator_backends.dp_backend import DataParallelBackend
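The three backends exported above do not share a formal base class; they follow the same implicit lifecycle that the Trainer drives later in this diff. A minimal sketch of that lifecycle, using only names that appear in the PR (the run_dp / run_single_gpu helpers are illustrative, not part of the commit):

    from pytorch_lightning.accelerator_backends import DataParallelBackend, GPUBackend

    def run_dp(trainer, model):
        # DataParallelBackend stores the wrapped model on the trainer and owns teardown
        backend = DataParallelBackend(trainer)
        backend.setup(model)       # moves the model to the root GPU, wraps it in DataParallel
        results = backend.train()  # runs the trainer's pretrain routine on trainer.model
        backend.teardown()         # restores the original model.forward
        return results

    def run_single_gpu(trainer, model):
        # GPUBackend instead hands the (possibly apex-wrapped) model back to the caller
        backend = GPUBackend(trainer)
        model = backend.setup(model)
        return backend.train(model)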
@@ -0,0 +1,117 @@
+import torch
+
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from torch import optim
+
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
+
+class DataParallelBackend(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.model_autocast_original_forward = None
+
+    def setup(self, model):
+        # call setup after the ddp process has connected
+        if not self.trainer.testing:
+            self.trainer.setup('fit')
+            model.setup('fit')
+
+        # put model on correct device
+        model.cuda(self.trainer.root_gpu)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # hack forward to do autocast for the user
+        self.model_autocast_original_forward = model.forward
+
+        # init half precision
+        if self.trainer.use_amp:
+            model = self.__init_half_precision(model)
+
+        # init torch data parallel
+        model = self.__init_torch_data_parallel(model)
+
+        self.trainer.model = model
+
+    def __init_torch_data_parallel(self, model):
+        # create list of device ids
+        device_ids = self.trainer.data_parallel_device_ids
+        if isinstance(device_ids, int):
+            device_ids = list(range(device_ids))
+
+        # set dp device
+        torch.cuda.set_device(self.trainer.root_gpu)
+        model = LightningDataParallel(model, device_ids=device_ids)
+        return model
+
+    def __init_half_precision(self, model):
+        native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
+
+        if native_amp_available:
+            self.__init_native_amp(model)
+        else:
+            model = self.__init_nvidia_apex(model)
+        return model
+
+    def __init_native_amp(self, model):
+        model.forward = torch.cuda.amp.autocast()(model.forward)
+
+    def __init_nvidia_apex(self, model):
+        # check for this bug (amp + dp + !01 doesn't work)
+        # https://github.com/NVIDIA/apex/issues/227
+        if self.trainer.amp_level == 'O2':
+            raise MisconfigurationException(
+                f'Amp level {self.trainer.amp_level} with DataParallel is not supported.'
+                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
+                f' We recommend you switch to ddp if you want to use amp')
+        else:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
+
+        return model
+
+    def train(self):
+        model = self.trainer.model
+        results = self.trainer.run_pretrain_routine(model)
+        return results
+
+    def teardown(self):
+        # replace the original fwd function
+        self.trainer.model.forward = self.model_autocast_original_forward
+
+    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
+        """
+        Reinitialize optimizer.step properties added by schedulers
+        """
+        for scheduler in schedulers:
+            scheduler = scheduler['scheduler']
+
+            for optimizer in optimizers:
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        is_regular_scheduler = optim.lr_scheduler._LRScheduler
+                        is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
+                        if is_regular_scheduler or is_lr_reduce_on_plateau:
+                            idx = i
+                            state = scheduler.state_dict()
+                        else:
+                            state = None
+
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+                    if state is not None:
+                        scheduler.load_state_dict(state)
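One spot in the new file deserves a closer look: in reinit_scheduler_properties, both is_regular_scheduler and is_lr_reduce_on_plateau are bound to class objects, so `if is_regular_scheduler or is_lr_reduce_on_plateau:` is always true and `idx` ends up pointing at the last entry of the MRO rather than the base scheduler class. A hedged sketch of what the check presumably intends — a suggested correction, not what the commit contains, and _rebuild_scheduler is a hypothetical helper name:

    from torch import optim

    def _rebuild_scheduler(scheduler, optimizer):
        # Suggested fix (not in this PR): record the MRO index only when the entry
        # really is one of the two base scheduler classes.
        state = None
        idx = None
        for i, mro in enumerate(scheduler.__class__.__mro__):
            if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
                idx = i
                state = scheduler.state_dict()
                break

        if idx is not None:
            scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
            if state is not None:
                scheduler.load_state_dict(state)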
@@ -14,13 +14,21 @@
 import torch
 
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
 
-class GPUAccelerator(object):
+class GPUBackend(object):
 
     def __init__(self, trainer):
         self.trainer = trainer
 
     def setup(self, model):
 
         # call setup
         if not self.trainer.testing:
             self.trainer.setup('fit')
@@ -38,9 +46,15 @@ class GPUAccelerator(object):
         # TODO: remove with dropping NVIDIA AMP support
         native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
         if self.trainer.use_amp and not native_amp_available:
-            self._setup_nvidia_apex(model)
+            model = self._setup_nvidia_apex(model)
+        return model
+
+    def train(self, model):
+        results = self.trainer.run_pretrain_routine(model)
+        return results
+
     def _setup_nvidia_apex(self, model):
         model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
         self.trainer.optimizers = optimizers
         self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
         return model
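The one-line change above, `model = self._setup_nvidia_apex(model)` instead of calling the helper purely for its side effects, matters because apex hands back patched objects. In Lightning modules of this era, configure_apex typically delegates to NVIDIA apex's amp.initialize, whose return values must replace the originals. A minimal sketch under that assumption (_setup_apex, trainer, model, optimizers are placeholders):

    from apex import amp  # assumes NVIDIA apex is installed

    def _setup_apex(trainer, model, amp_level='O1'):
        # amp.initialize returns re-wrapped objects; keeping references to the old
        # model/optimizers would silently bypass mixed precision.
        model, optimizers = amp.initialize(model, trainer.optimizers, opt_level=amp_level)
        trainer.optimizers = optimizers
        return model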
@@ -28,7 +28,7 @@ else:
     XLA_AVAILABLE = True
 
 
-class TPUAccelerator(object):
+class TPUBackend(object):
 
     def __init__(self, trainer):
         self.trainer = trainer
@@ -1,2 +0,0 @@
-from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
-from pytorch_lightning.accelerators.tpu_accelerator import TPUAccelerator
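For code that imported the old module path directly, the deleted __init__ above together with the new package amounts to a rename. The before/after, taken from the import change to the Trainer later in this diff:

    # before this commit
    from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator

    # after this commit
    from pytorch_lightning.accelerator_backends import GPUBackend, TPUBackend, DataParallelBackend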
@@ -179,52 +179,6 @@ class TrainerDPMixin(ABC):
             return model.transfer_batch_to_device(batch, device)
         return move_data_to_device(batch, device)
 
-    def dp_train(self, model):
-        # call setup after the ddp process has connected
-        if not self.testing:
-            self.setup('fit')
-            model.setup('fit')
-
-        model.cuda(self.root_gpu)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
-
-        # hack forward to do autocast for the user
-        model_autocast_original_forward = model.forward
-        if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
-            # wrap the user's forward in autocast and give it back at the end
-            model.forward = torch.cuda.amp.autocast()(model.forward)
-
-        # TODO: remove with dropping NVIDIA AMP support
-        # check for this bug (amp + dp + !01 doesn't work)
-        # https://github.com/NVIDIA/apex/issues/227
-        if self.use_dp and self.use_amp and not NATIVE_AMP_AVALAIBLE and not self.use_tpu:
-            if self.amp_level == 'O2':
-                raise MisconfigurationException(
-                    f'Amp level {self.amp_level} with DataParallel is not supported.'
-                    f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
-                    f' We recommend you switch to ddp if you want to use amp')
-            else:
-                model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
-                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
-
-        # create list of device ids
-        device_ids = self.data_parallel_device_ids
-        if isinstance(device_ids, int):
-            device_ids = list(range(device_ids))
-
-        # set dp device
-        torch.cuda.set_device(self.root_gpu)
-
-        model = LightningDataParallel(model, device_ids=device_ids)
-
-        result = self.run_pretrain_routine(model)
-        model.forward = model_autocast_original_forward
-
-        return result
-
     def horovod_train(self, model):
         # call setup after the ddp process has connected
         if not self.testing:
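Both the removed dp_train above and the new DataParallelBackend rely on the same trick: torch.cuda.amp.autocast() works as a decorator, so re-binding model.forward to the decorated function makes every forward pass run under autocast until the original is restored. A standalone sketch on a CUDA machine (TinyNet is illustrative):

    import torch

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.layer(x)

    model = TinyNet().cuda()
    original_forward = model.forward                   # keep a handle for teardown
    model.forward = torch.cuda.amp.autocast()(model.forward)

    out = model(torch.randn(4, 8, device='cuda'))
    print(out.dtype)                                   # torch.float16 under autocast

    model.forward = original_forward                   # restore, as teardown() does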
@@ -51,7 +51,7 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only,
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
-from pytorch_lightning.accelerators import GPUAccelerator, TPUAccelerator
+from pytorch_lightning.accelerator_backends import GPUBackend, TPUBackend, DataParallelBackend
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -661,7 +661,7 @@ class Trainer(
         # tracks internal state for debugging
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
-        self.accelerator = None
+        self.accelerator_backend = None
 
         # Callback system
         self.on_init_end()
@@ -1064,24 +1064,25 @@ class Trainer(
             self.set_random_port()
             results = self.spawn_ddp_children(model)
 
         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
         elif self.use_dp:
-            results = self.dp_train(model)
+            self.accelerator_backend = DataParallelBackend(self)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
+            self.accelerator_backend.teardown()
 
         elif self.use_horovod:
             results = self.horovod_train(model)
 
         elif self.single_gpu:
-            self.accelerator = GPUAccelerator(self)
-            self.accelerator.setup(model)
-            results = self.run_pretrain_routine(model)
+            self.accelerator_backend = GPUBackend(self)
+            model = self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train(model)
 
         elif self.use_tpu:
-            self.accelerator = TPUAccelerator(self)
-            self.accelerator.setup()
-            self.accelerator.train(model)
-            self.accelerator.teardown()
+            self.accelerator_backend = TPUBackend(self)
+            self.accelerator_backend.setup()
+            self.accelerator_backend.train(model)
+            self.accelerator_backend.teardown()
 
         # ON CPU
         else:
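With the dispatch above, the DP path now runs entirely through accelerator_backend. A hedged end-to-end sketch of how a user would reach it — the Trainer arguments (gpus, distributed_backend) and the training_step/configure_optimizers shape follow the Lightning API of this period, and ToyModel is a placeholder:

    import torch
    import pytorch_lightning as pl

    class ToyModel(pl.LightningModule):
        # Minimal LightningModule, just enough to drive the DP path.
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            return {'loss': loss}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            dataset = torch.utils.data.TensorDataset(
                torch.randn(256, 32), torch.randint(0, 2, (256,)))
            return torch.utils.data.DataLoader(dataset, batch_size=32)

    # gpus=2 with distributed_backend='dp' makes use_dp true, so fit() goes through
    # DataParallelBackend.setup() / .train() / .teardown() as in the hunk above.
    trainer = pl.Trainer(gpus=2, distributed_backend='dp', max_epochs=1)
    trainer.fit(ToyModel())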
@@ -38,6 +38,9 @@ def test_single_gpu_test(tmpdir):
 def test_dp_test(tmpdir):
     tutils.set_random_master_port()
 
+    import os
+    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
+
     model = EvalModelTemplate()
     trainer = pl.Trainer(
         default_root_dir=os.getcwd(),