ref: apex plugin (#3502)
* ref: apex plugin
* ref: apex plugin
* ref: apex plugin
parent 61b31d94b4
commit 810b445097
@@ -141,3 +141,4 @@ Indices and tables
   api/pytorch_lightning.trainer
   api/pytorch_lightning.utilities
   api/pytorch_lightning.tuner
   api/pytorch_lightning.plugins

@@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.plugins.apex import ApexPlugin

try:
    from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,17 +32,13 @@ except ImportError:
else:
    HYDRA_AVAILABLE = True

try:
    from apex import amp
except ImportError:
    amp = None


class DDP2Backend(DDPBase):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.task_idx = None
        self.precision_backend = None

    def setup(self, model):
        self._resolve_task_idx()

@@ -22,6 +22,7 @@ import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin

try:
    from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,16 +32,12 @@ except ImportError:
else:
    HYDRA_AVAILABLE = True

try:
    from apex import amp
except ImportError:
    amp = None


class DDPBase(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def training_step(self, args):
        if self.trainer.amp_backend == AMPType.NATIVE:
@@ -155,9 +152,8 @@ class DDPBase(Accelerator):
        # AMP -
        # run through amp wrapper before going to distributed DP
        if self.trainer.amp_backend == AMPType.APEX:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # device ids change depending on the DDP setup
        device_ids = self.get_device_ids()

@@ -20,6 +20,7 @@ from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.plugins.apex import ApexPlugin

try:
    from apex import amp
@@ -32,6 +33,7 @@ class DataParallelBackend(Accelerator):
    def __init__(self, trainer):
        super().__init__(trainer)
        self.model_autocast_original_forward = None
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
@@ -89,8 +91,8 @@ class DataParallelBackend(Accelerator):
                f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                f' We recommend you switch to ddp if you want to use amp')
        else:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        return model

@@ -13,14 +13,9 @@
# limitations under the License.

import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
    from apex import amp
except ImportError:
    amp = None
from pytorch_lightning.plugins.apex import ApexPlugin


class GPUBackend(Accelerator):

@@ -28,6 +23,7 @@ class GPUBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):

@@ -45,7 +41,8 @@ class GPUBackend(Accelerator):
        self.trainer.optimizer_frequencies = optimizer_frequencies

        if self.trainer.amp_backend == AMPType.APEX:
            model = self._setup_nvidia_apex(model)
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        self.trainer.model = model

@@ -117,9 +114,3 @@ class GPUBackend(Accelerator):
        # be referenced from and if there are multiple optimizers the batch will
        # wind up copying it to the same device repeatedly.
        return self.batch_to_device(batch, gpu_id)

    def _setup_nvidia_apex(self, model: LightningModule):
        model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
        self.trainer.optimizers = optimizers
        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
        return model

@@ -18,12 +18,7 @@ from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim.lr_scheduler import _LRScheduler

try:
    from apex import amp
except ImportError:
    amp = None

from pytorch_lightning.plugins.apex import ApexPlugin

try:
    import horovod.torch as hvd
@@ -38,6 +33,7 @@ class HorovodBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.precision_backend = None

    def setup(self, model):
        # call setup after the ddp process has connected
@@ -88,9 +84,8 @@ class HorovodBackend(Accelerator):
        ]

        if self.trainer.amp_backend == AMPType.APEX:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
            self.precision_backend = ApexPlugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.

@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    from apex import amp
except ImportError:
    amp = None


class ApexPlugin:

    def __init__(self, trainer):
        self.trainer = trainer

    def _init(self, model):
        model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level)
        self.trainer.optimizers = optimizers
        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
        return model, optimizers

    def configure_apex(self, model, optimizers, amp_level):
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers

    def training_step(self, fx, args):
        output = fx(args)
        return output
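Every backend touched above now follows the same two-line pattern: build an ApexPlugin from the trainer and let its _init hook run amp.initialize and refresh the optimizers and schedulers, instead of each backend calling model.configure_apex inline. Below is a minimal sketch of that wiring; _StubTrainer and BackendSketch are hypothetical stand-ins for pytorch_lightning.Trainer and the real accelerator backends, and apex needs to be installed for the wrapped call to do any actual mixed-precision work.

# Minimal sketch (not part of this commit) of the wiring pattern the backends
# above adopt. `_StubTrainer` and `BackendSketch` are hypothetical stand-ins;
# the real code lives in the accelerator backends shown in the diff.
from pytorch_lightning.plugins.apex import ApexPlugin


class _StubTrainer:
    """Exposes only the attributes ApexPlugin._init touches."""

    def __init__(self, optimizers, lr_schedulers, amp_level="O2"):
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.amp_level = amp_level

    def reinit_scheduler_properties(self, optimizers, schedulers):
        # The real Trainer re-binds each scheduler to the amp-wrapped
        # optimizer here; a no-op is enough for the sketch.
        pass


class BackendSketch:

    def __init__(self, trainer):
        self.trainer = trainer
        self.precision_backend = None  # set lazily in setup, as in the diff

    def setup(self, model):
        # Replaces the old inline `model.configure_apex(...)` call: the plugin
        # calls amp.initialize, stores the wrapped optimizers back on the
        # trainer, and re-initializes the scheduler properties.
        self.precision_backend = ApexPlugin(self.trainer)
        model, optimizers = self.precision_backend._init(model)
        return model

The upshot of the refactor is that amp.initialize and the scheduler re-initialization live in one place (the plugin) rather than being repeated across the GPU, DDP, DDP2, DP and Horovod backends.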