ref: apex plugin (#3502)

* ref: apex plugin

* ref: apex plugin

* ref: apex plugin
This commit is contained in:
William Falcon 2020-09-15 06:02:42 -04:00 committed by GitHub
parent 61b31d94b4
commit 810b445097
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 57 additions and 37 deletions

View File

@ -141,3 +141,4 @@ Indices and tables
api/pytorch_lightning.trainer api/pytorch_lightning.trainer
api/pytorch_lightning.utilities api/pytorch_lightning.utilities
api/pytorch_lightning.tuner api/pytorch_lightning.tuner
api/pytorch_lightning.plugins

View File

@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.plugins.apex import ApexPlugin
try: try:
from hydra.utils import to_absolute_path, get_original_cwd from hydra.utils import to_absolute_path, get_original_cwd
@ -31,17 +32,13 @@ except ImportError:
else: else:
HYDRA_AVAILABLE = True HYDRA_AVAILABLE = True
try:
from apex import amp
except ImportError:
amp = None
class DDP2Backend(DDPBase): class DDP2Backend(DDPBase):
def __init__(self, trainer): def __init__(self, trainer):
super().__init__(trainer) super().__init__(trainer)
self.task_idx = None self.task_idx = None
self.precision_backend = None
def setup(self, model): def setup(self, model):
self._resolve_task_idx() self._resolve_task_idx()

View File

@ -22,6 +22,7 @@ import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
try: try:
from hydra.utils import to_absolute_path, get_original_cwd from hydra.utils import to_absolute_path, get_original_cwd
@ -31,16 +32,12 @@ except ImportError:
else: else:
HYDRA_AVAILABLE = True HYDRA_AVAILABLE = True
try:
from apex import amp
except ImportError:
amp = None
class DDPBase(Accelerator): class DDPBase(Accelerator):
def __init__(self, trainer): def __init__(self, trainer):
super().__init__(trainer) super().__init__(trainer)
self.precision_backend = None
def training_step(self, args): def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE: if self.trainer.amp_backend == AMPType.NATIVE:
@ -155,9 +152,8 @@ class DDPBase(Accelerator):
# AMP - # AMP -
# run through amp wrapper before going to distributed DP # run through amp wrapper before going to distributed DP
if self.trainer.amp_backend == AMPType.APEX: if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.precision_backend = ApexPlugin(self.trainer)
self.trainer.optimizers = optimizers model, optimizers = self.precision_backend._init(model)
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
# device ids change depending on the DDP setup # device ids change depending on the DDP setup
device_ids = self.get_device_ids() device_ids = self.get_device_ids()

View File

@ -20,6 +20,7 @@ from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.plugins.apex import ApexPlugin
try: try:
from apex import amp from apex import amp
@ -32,6 +33,7 @@ class DataParallelBackend(Accelerator):
def __init__(self, trainer): def __init__(self, trainer):
super().__init__(trainer) super().__init__(trainer)
self.model_autocast_original_forward = None self.model_autocast_original_forward = None
self.precision_backend = None
def setup(self, model): def setup(self, model):
# call setup after the ddp process has connected # call setup after the ddp process has connected
@ -89,8 +91,8 @@ class DataParallelBackend(Accelerator):
f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.' f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
f' We recommend you switch to ddp if you want to use amp') f' We recommend you switch to ddp if you want to use amp')
else: else:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.precision_backend = ApexPlugin(self.trainer)
self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers) model, optimizers = self.precision_backend._init(model)
return model return model

View File

@ -13,14 +13,9 @@
# limitations under the License. # limitations under the License.
import torch import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.plugins.apex import ApexPlugin
try:
from apex import amp
except ImportError:
amp = None
class GPUBackend(Accelerator): class GPUBackend(Accelerator):
@ -28,6 +23,7 @@ class GPUBackend(Accelerator):
def __init__(self, trainer): def __init__(self, trainer):
super().__init__(trainer) super().__init__(trainer)
self.precision_backend = None
def setup(self, model): def setup(self, model):
@ -45,7 +41,8 @@ class GPUBackend(Accelerator):
self.trainer.optimizer_frequencies = optimizer_frequencies self.trainer.optimizer_frequencies = optimizer_frequencies
if self.trainer.amp_backend == AMPType.APEX: if self.trainer.amp_backend == AMPType.APEX:
model = self._setup_nvidia_apex(model) self.precision_backend = ApexPlugin(self.trainer)
model, optimizers = self.precision_backend._init(model)
self.trainer.model = model self.trainer.model = model
@ -117,9 +114,3 @@ class GPUBackend(Accelerator):
# be referenced from and if there are multiple optimizers the batch will # be referenced from and if there are multiple optimizers the batch will
# wind up copying it to the same device repeatedly. # wind up copying it to the same device repeatedly.
return self.batch_to_device(batch, gpu_id) return self.batch_to_device(batch, gpu_id)
def _setup_nvidia_apex(self, model: LightningModule):
    """Run ``model`` and the trainer's optimizers through ``model.configure_apex``.

    The optimizers returned by ``configure_apex`` are stored back on the
    trainer, after which the trainer is asked to re-initialise its scheduler
    properties against them.

    Args:
        model: module whose ``configure_apex`` hook is invoked.

    Returns:
        The model returned by ``configure_apex``.
    """
    trainer = self.trainer
    wrapped_model, wrapped_optimizers = model.configure_apex(
        amp,
        model,
        trainer.optimizers,
        trainer.amp_level,
    )
    trainer.optimizers = wrapped_optimizers
    # the schedulers are re-bound after the optimizers on the trainer were replaced
    trainer.reinit_scheduler_properties(trainer.optimizers, trainer.lr_schedulers)
    return wrapped_model

View File

@ -18,12 +18,7 @@ from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning.plugins.apex import ApexPlugin
try:
from apex import amp
except ImportError:
amp = None
try: try:
import horovod.torch as hvd import horovod.torch as hvd
@ -38,6 +33,7 @@ class HorovodBackend(Accelerator):
def __init__(self, trainer): def __init__(self, trainer):
super().__init__(trainer) super().__init__(trainer)
self.precision_backend = None
def setup(self, model): def setup(self, model):
# call setup after the ddp process has connected # call setup after the ddp process has connected
@ -88,9 +84,8 @@ class HorovodBackend(Accelerator):
] ]
if self.trainer.amp_backend == AMPType.APEX: if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.precision_backend = ApexPlugin(self.trainer)
self.trainer.optimizers = optimizers model, optimizers = self.precision_backend._init(model)
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
# Update logger rank info from Horovod to avoid race conditions from different ranks # Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations. # creating directories / writing files in the same locations.

View File

View File

@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
    from apex import amp
except ImportError:
    # apex is optional; ``amp`` stays ``None`` so this module can still be imported
    amp = None


class ApexPlugin:
    """Route a trainer's model and optimizers through ``apex.amp``.

    Args:
        trainer: object exposing ``optimizers``, ``lr_schedulers``,
            ``amp_level`` and ``reinit_scheduler_properties``; it is stored
            as ``self.trainer``.
    """

    def __init__(self, trainer):
        self.trainer = trainer

    def _init(self, model):
        """Configure apex for ``model`` and sync the result back to the trainer.

        Calls :meth:`configure_apex` with the trainer's current optimizers and
        ``amp_level``, stores the returned optimizers on the trainer and then
        asks the trainer to re-initialise its scheduler properties.

        Args:
            model: the model to hand to :meth:`configure_apex`.

        Returns:
            The ``(model, optimizers)`` pair returned by :meth:`configure_apex`.
        """
        model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level)
        self.trainer.optimizers = optimizers
        # the schedulers are re-bound after the optimizers on the trainer were replaced
        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
        return model, optimizers

    def configure_apex(self, model, optimizers, amp_level):
        """Call ``amp.initialize`` on the given model and optimizers.

        Args:
            model: forwarded to ``amp.initialize`` as the model.
            optimizers: forwarded to ``amp.initialize`` as the optimizers.
            amp_level: forwarded to ``amp.initialize`` as ``opt_level``.

        Returns:
            The ``(model, optimizers)`` pair returned by ``amp.initialize``.

        Raises:
            ImportError: if ``apex`` could not be imported when this module
                was loaded.
        """
        if amp is None:
            # fail with an actionable message instead of
            # "'NoneType' object has no attribute 'initialize'"
            raise ImportError(
                'ApexPlugin requires NVIDIA apex, but `from apex import amp` failed.'
                ' Install apex from https://github.com/NVIDIA/apex to use it.'
            )
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers

    def training_step(self, fx, args):
        """Call ``fx`` with ``args`` as its single positional argument.

        Args:
            fx: callable to invoke.
            args: value passed to ``fx`` unchanged (it is not unpacked).

        Returns:
            Whatever ``fx(args)`` returns.
        """
        output = fx(args)
        return output