ref: apex plugin (#3502)
* ref: apex plugin
* ref: apex plugin
* ref: apex plugin
parent 61b31d94b4
commit 810b445097
@@ -141,3 +141,4 @@ Indices and tables
    api/pytorch_lightning.trainer
    api/pytorch_lightning.utilities
    api/pytorch_lightning.tuner
+   api/pytorch_lightning.plugins
@@ -22,6 +22,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+from pytorch_lightning.plugins.apex import ApexPlugin
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,17 +32,13 @@ except ImportError:
 else:
     HYDRA_AVAILABLE = True
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
 class DDP2Backend(DDPBase):
 
     def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
+        self.precision_backend = None
 
     def setup(self, model):
         self._resolve_task_idx()

@@ -22,6 +22,7 @@ import torch.distributed as dist
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
 from pytorch_lightning import _logger as log
+from pytorch_lightning.plugins.apex import ApexPlugin
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,16 +32,12 @@ except ImportError:
 else:
     HYDRA_AVAILABLE = True
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
 class DDPBase(Accelerator):
 
     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None
 
     def training_step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
@@ -155,9 +152,8 @@ class DDPBase(Accelerator):
         # AMP -
         # run through amp wrapper before going to distributed DP
         if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)
 
         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()

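The same substitution repeats in every backend touched by this commit: the inline `model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)` call (plus the optimizer and scheduler rewiring around it) is replaced by constructing an `ApexPlugin` and delegating to its `_init(model)`. Below is a minimal, self-contained sketch of that delegation pattern; `_Plugin`, `_Trainer`, and `_Backend` are illustrative stand-ins, not the actual Lightning classes.

# Sketch only: how a backend now defers apex setup to a precision plugin.
# _Plugin/_Trainer/_Backend are hypothetical stubs standing in for
# ApexPlugin, the Lightning Trainer, and an accelerator backend.

class _Plugin:
    def __init__(self, trainer):
        self.trainer = trainer

    def _init(self, model):
        # the real ApexPlugin calls amp.initialize(...) here and rewires
        # trainer.optimizers / LR schedulers; this stub just passes through
        return model, self.trainer.optimizers


class _Trainer:
    def __init__(self):
        self.optimizers = ["sgd"]
        self.amp_backend = "apex"


class _Backend:
    def __init__(self, trainer):
        self.trainer = trainer
        self.precision_backend = None  # attribute every backend now initialises

    def setup(self, model):
        if self.trainer.amp_backend == "apex":
            # before: model, optimizers = model.configure_apex(amp, model, ...)
            # after:  delegate to the plugin
            self.precision_backend = _Plugin(self.trainer)
            model, optimizers = self.precision_backend._init(model)
        return model


print(_Backend(_Trainer()).setup(model="dummy"))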
@@ -20,6 +20,7 @@ from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.plugins.apex import ApexPlugin
 
 try:
     from apex import amp
@@ -32,6 +33,7 @@ class DataParallelBackend(Accelerator):
     def __init__(self, trainer):
         super().__init__(trainer)
         self.model_autocast_original_forward = None
+        self.precision_backend = None
 
     def setup(self, model):
         # call setup after the ddp process has connected
@@ -89,8 +91,8 @@ class DataParallelBackend(Accelerator):
                 f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                 f' We recommend you switch to ddp if you want to use amp')
         else:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)
 
         return model
 

@@ -13,14 +13,9 @@
 # limitations under the License.
 
 import torch
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.plugins.apex import ApexPlugin
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
 
 class GPUBackend(Accelerator):
@@ -28,6 +23,7 @@ class GPUBackend(Accelerator):
 
     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None
 
     def setup(self, model):
 
@@ -45,7 +41,8 @@ class GPUBackend(Accelerator):
         self.trainer.optimizer_frequencies = optimizer_frequencies
 
         if self.trainer.amp_backend == AMPType.APEX:
-            model = self._setup_nvidia_apex(model)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)
 
         self.trainer.model = model
 
@@ -117,9 +114,3 @@ class GPUBackend(Accelerator):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
-
-    def _setup_nvidia_apex(self, model: LightningModule):
-        model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-        self.trainer.optimizers = optimizers
-        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-        return model

@@ -18,12 +18,7 @@ from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from torch.optim.lr_scheduler import _LRScheduler
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
+from pytorch_lightning.plugins.apex import ApexPlugin
 
 try:
     import horovod.torch as hvd
@@ -38,6 +33,7 @@ class HorovodBackend(Accelerator):
 
     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None
 
     def setup(self, model):
         # call setup after the ddp process has connected
@@ -88,9 +84,8 @@ class HorovodBackend(Accelerator):
         ]
 
         if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)
 
         # Update logger rank info from Horovod to avoid race conditions from different ranks
         # creating directories / writing files in the same locations.

@@ -0,0 +1,38 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+class ApexPlugin:
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def _init(self, model):
+        model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level)
+        self.trainer.optimizers = optimizers
+        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+        return model, optimizers
+
+    def configure_apex(self, model, optimizers, amp_level):
+        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
+        return model, optimizers
+
+    def training_step(self, fx, args):
+        output = fx(args)
+        return output