allow using apex with any PT version (#2865)
* wip
* setup
* type
* name
* wip
* docs
* imports
* fix if
* fix if
* use_amp
* Apply suggestions from code review
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* Apply suggestions from code review
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* fix tests
* Apply suggestions from code review
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* fix tests
* todos

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent fed0ac838b
commit a6e7aa7796
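This change decouples 16-bit training from the installed PyTorch version: a new `amp_type` Trainer argument picks the mixed-precision backend (NVIDIA Apex, or the native AMP shipped with PyTorch 1.6+), and the resolved choice is carried around as an `AMPType` enum instead of the old `use_amp` / `NATIVE_AMP_AVALAIBLE` checks. A minimal usage sketch based on the docs added further down; `precision=16` and `amp_level` come from the pre-existing Trainer API and are shown here only for context:

    from pytorch_lightning import Trainer

    # NVIDIA Apex backend: works on any PyTorch version, requires apex to be installed
    trainer = Trainer(precision=16, amp_type='apex', amp_level='O2')

    # native torch.cuda.amp backend: requires PyTorch >= 1.6
    trainer = Trainer(precision=16, amp_type='native')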
@@ -39,7 +39,6 @@ RUN apt-get update && \
&& \
# Install AMP
# TODO: skip this instrall for PT >= 1.6
bash install_AMP.sh && \
# Install all requirements
pip install -r requirements.txt && \
@@ -22,7 +22,7 @@ class CPUBackend(object):
def setup(self, model):
# run through amp wrapper
if self.trainer.use_amp:
if self.trainer.amp_type:
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
# call setup after the ddp process has connected
@@ -17,7 +17,7 @@ import os
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -32,9 +32,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDP2Backend(object):

@@ -135,10 +133,8 @@ class DDP2Backend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
@@ -23,7 +23,7 @@ import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
try:

@@ -37,9 +37,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDPBackend(object):

@@ -202,10 +200,8 @@ class DDPBackend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
@@ -16,14 +16,13 @@ import torch
import torch.multiprocessing as mp
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDPSpawnBackend(object):

@@ -133,11 +132,9 @@ class DDPSpawnBackend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# AMP -
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
@@ -16,14 +16,13 @@ import torch
from torch import optim
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DataParallelBackend(object):

@@ -50,7 +49,7 @@ class DataParallelBackend(object):
self.model_autocast_original_forward = model.forward
# init half precision
if self.trainer.use_amp:
if self.trainer.amp_type:
model = self.__init_half_precision(model)
# init torch data parallel

@@ -70,9 +69,7 @@ class DataParallelBackend(object):
return model
def __init_half_precision(self, model):
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if native_amp_available:
if self.trainer.amp_type == AMPType.NATIVE:
self.__init_native_amp(model)
else:
model = self.__init_nvidia_apex(model)
@@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class GPUBackend(object):
amp_type: AMPType
def __init__(self, trainer):
self.trainer = trainer

@@ -43,9 +41,7 @@ class GPUBackend(object):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies
# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
if self.trainer.amp_type == AMPType.APEX:
model = self._setup_nvidia_apex(model)
return model
@@ -5,14 +5,12 @@ from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import move_data_to_device, AMPType
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class ModelHooks(Module):

@@ -267,8 +265,8 @@ class ModelHooks(Module):
"""
loss.backward()
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
if NATIVE_AMP_AVALAIBLE:
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType):
if amp_type == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
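One subtlety worth flagging in the `amp_scale_loss` change above: with `AMPType.NATIVE` the hook returns a tensor already scaled by the trainer's `GradScaler`, whereas with `AMPType.APEX` it returns the `amp.scale_loss(...)` context manager, which the caller must enter to get the scaled loss (the training-loop hunks further down do exactly that via `__enter__`/`__exit__`). A rough sketch of consuming both cases, assuming `model`, `trainer`, `optimizer` and `loss` already exist:

    from pytorch_lightning.utilities import AMPType

    scaled = model.amp_scale_loss(loss, optimizer, 0, amp_type=trainer.amp_type)
    if trainer.amp_type == AMPType.APEX:
        # Apex hands back a context manager; entering it yields the scaled loss
        with scaled as apex_scaled_loss:
            apex_scaled_loss.backward()
    else:
        # native AMP hands back a plain scaled tensor
        scaled.backward()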
@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
UNKNOWN_SIZE = "?"

@@ -207,8 +207,7 @@ class ModelSummary(object):
input_ = model.example_input_array
input_ = model.transfer_batch_to_device(input_, model.device)
if trainer is not None and trainer.use_amp and not trainer.use_tpu:
if NATIVE_AMP_AVALAIBLE:
if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu:
model.forward = torch.cuda.amp.autocast()(model.forward)
mode = model.training
@@ -864,6 +864,19 @@ Enable synchronization between batchnorm layers across all GPUs.

    trainer = Trainer(sync_batchnorm=True)

amp_type
^^^^^^^^

Define a preferable mixed precision, either NVIDIA Apex ("apex") or PyTorch built-in ("native") AMP which is supported from v1.6.

.. testcode::

    # using NVIDIA Apex
    trainer = Trainer(amp_type='apex')

    # using PyTorch built-in AMP
    trainer = Trainer(amp_type='native')

val_percent_check
^^^^^^^^^^^^^^^^^
@@ -1,7 +1,7 @@
from abc import ABC
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
class TrainerAMPMixin(ABC):

@@ -10,26 +10,39 @@ class TrainerAMPMixin(ABC):
# the proper values/initialisation should be done in child class
precision: int
def init_amp(self):
if NATIVE_AMP_AVALAIBLE:
log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
if self.use_amp and NATIVE_AMP_AVALAIBLE:
log.info('Using native 16bit precision.')
def _setup_amp_type(self, amp_type: str):
self.amp_type = None
if self.precision != 16:
# no AMP requested, so we can leave now
return
if self.use_amp and not APEX_AVAILABLE: # pragma: no-cover
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
if not NATIVE_AMP_AVALAIBLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
amp_type = 'apex'
else:
log.info('Using native 16bit precision.')
self.amp_type = AMPType.NATIVE
if amp_type == 'apex':
if not APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:
log.info('Using APEX 16bit precision.')
self.amp_type = AMPType.APEX
if not self.amp_type:
raise ModuleNotFoundError(
"You set `use_amp=True` but do not have apex installed."
" Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
" and rerun with `use_amp=True`."
" This run will NOT use 16 bit precision."
f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
f' Consider installing torch >= 1.6 or NVIDIA Apex.'
)
if self.use_amp:
log.info('Using APEX 16bit precision.')
def init_amp(self, amp_type: str):
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
self._setup_amp_type(amp_type)
@property
def use_amp(self) -> bool:
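Because the hunk above interleaves removed and added lines, the new `_setup_amp_type` body is easier to follow when read in one piece. The following is a reconstruction of the added logic as it appears in the diff, with the long warning strings shortened; `NATIVE_AMP_AVALAIBLE` and `APEX_AVAILABLE` are the module-level flags imported at the top of the file:

    def _setup_amp_type(self, amp_type: str):
        self.amp_type = None
        if self.precision != 16:
            # no AMP requested, so we can leave now
            return

        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if not NATIVE_AMP_AVALAIBLE:
                rank_zero_warn('native AMP needs PyTorch >= 1.6, falling back to NVIDIA Apex')
                amp_type = 'apex'  # fall through to the Apex branch below
            else:
                log.info('Using native 16bit precision.')
                self.amp_type = AMPType.NATIVE

        if amp_type == 'apex':
            if not APEX_AVAILABLE:
                rank_zero_warn('Apex AMP requested but apex is not installed')
            else:
                log.info('Using APEX 16bit precision.')
                self.amp_type = AMPType.APEX

        if not self.amp_type:
            raise ModuleNotFoundError(
                f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
                f' Consider installing torch >= 1.6 or NVIDIA Apex.'
            )

In short: Apex is either an explicit opt-in or the fallback when native AMP is unavailable, and if neither backend is usable while `precision=16`, initialisation fails loudly.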
@@ -22,9 +22,7 @@ except ImportError:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla

@@ -131,24 +131,27 @@ import os
import re
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Optional, Tuple
from typing import Union, List, Optional, Callable, Tuple
import subprocess
import sys
from time import sleep
import numpy as np
import torch
from os.path import abspath
from pkg_resources import parse_version
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import horovod.torch as hvd

@@ -207,11 +210,6 @@ class TrainerDDPMixin(ABC):
def num_gpus(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""
@property
@abstractmethod
def use_amp(self) -> bool:
"""Warning: this is just empty shell for code implemented in other class."""
@abstractmethod
def copy_trainer_model_properties(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""
@@ -18,30 +18,29 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
import random
import time
from abc import ABC, abstractmethod
from contextlib import ExitStack
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import move_data_to_device, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla.core.xla_model as xm

@@ -80,11 +79,7 @@ class TrainerDPMixin(ABC):
on_colab_kaggle: str
save_spawn_weights: Callable
logger: ...
@property
@abstractmethod
def use_amp(self) -> bool:
"""Warning: this is just empty shell for code implemented in other class."""
amp_type: AMPType
@abstractmethod
def call_setup_hook(self, *args):

@@ -128,7 +123,7 @@ class TrainerDPMixin(ABC):
m.use_dp = self.use_dp
m.use_ddp2 = self.use_ddp2
m.use_ddp = self.use_ddp
m.use_amp = self.use_amp
m.use_amp = self.amp_type is not None
m.testing = self.testing
m.use_single_gpu = self.use_single_gpu
m.use_tpu = self.use_tpu

@@ -210,7 +205,7 @@ class TrainerDPMixin(ABC):
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
if self.use_amp:
if self.amp_type:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
@@ -131,10 +131,11 @@ import torch
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
from pytorch_lightning.core.step_result import Result, EvalResult
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.core.xla_model as xm

@@ -179,6 +180,7 @@ class TrainerEvaluationLoopMixin(ABC):
tpu_id: int
verbose_test: bool
running_sanity_check: bool
amp_type: AMPType
# Callback system
on_validation_batch_start: Callable

@@ -316,7 +318,7 @@ class TrainerEvaluationLoopMixin(ABC):
# -----------------
# RUN EVALUATION STEP
# -----------------
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
else:
@@ -31,7 +31,7 @@ from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator

@@ -49,7 +49,7 @@ from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -61,9 +61,7 @@ warnings.filterwarnings(
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla

@@ -199,6 +197,7 @@ class Trainer(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
amp_type: str = 'native',
amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0
val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0
test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0

@@ -308,6 +307,7 @@ class Trainer(
Defaults to `default_root_dir`.
amp_level: The optimization level to use (O1, O2, etc...).
.. warning:: .. deprecated:: v0.7.4
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2

@@ -587,7 +587,7 @@ class Trainer(
self.scaler = None
self.amp_level = amp_level
self.init_amp()
self.init_amp(amp_type)
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

@@ -1128,7 +1128,7 @@ class Trainer(
self.copy_trainer_model_properties(ref_model)
# init amp. Must be done here instead of __init__ to allow ddp to work
if NATIVE_AMP_AVALAIBLE and self.precision == 16 and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
self.scaler = torch.cuda.amp.GradScaler()
# log hyper-parameters
@@ -102,7 +102,7 @@ from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import load as pl_load
try:

@@ -117,9 +117,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import horovod.torch as hvd

@@ -157,8 +155,9 @@ class TrainerIOMixin(ABC):
on_tpu: bool
num_training_batches: int
accumulate_grad_batches: int
use_amp: bool
scaler: ...
use_tpu: bool
amp_type: AMPType
def get_model(self):
is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))

@@ -323,9 +322,9 @@ class TrainerIOMixin(ABC):
model.cuda(self.root_gpu)
# restore amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])
# load training state (affects trainer only)

@@ -376,9 +375,9 @@ class TrainerIOMixin(ABC):
checkpoint['lr_schedulers'] = lr_schedulers
# save native amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
elif self.use_amp and not NATIVE_AMP_AVALAIBLE:
elif self.amp_type == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()
# add the module_arguments and state_dict from the model

@@ -533,9 +532,9 @@ class TrainerIOMixin(ABC):
model.load_state_dict(checkpoint['state_dict'])
# restore amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])
if self.root_gpu is not None:
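A condensed view of the checkpoint handling above, which now keys purely off the resolved `amp_type` rather than off what happens to be importable (a sketch only; `checkpoint` is the dict being saved or restored, `trainer.scaler` the native `torch.cuda.amp.GradScaler`, and `amp` NVIDIA Apex):

    from pytorch_lightning.utilities import AMPType

    # saving
    if trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu:
        checkpoint['native_amp_scaling_state'] = trainer.scaler.state_dict()
    elif trainer.amp_type == AMPType.APEX:
        checkpoint['amp_scaling_state'] = amp.state_dict()

    # restoring
    if trainer.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
        trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    elif trainer.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
        amp.load_state_dict(checkpoint['amp_scaling_state'])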
@@ -175,7 +175,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.parsing import AttributeDict

@@ -183,9 +183,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla.distributed.parallel_loader as xla_pl

@@ -255,6 +253,8 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
amp_type: AMPType
on_tpu: bool
# Callback system
callbacks: List[Callback]

@@ -739,7 +739,7 @@ class TrainerTrainLoopMixin(ABC):
batch_idx,
opt_idx,
optimizer,
self.hiddens
self.hiddens,
)
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)

@@ -835,7 +835,7 @@ class TrainerTrainLoopMixin(ABC):
# ------------------
# CLIP GRADS
# ------------------
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
self.scaler.unscale_(optimizer)
self.clip_gradients(optimizer)

@@ -857,7 +857,7 @@ class TrainerTrainLoopMixin(ABC):
batch_idx,
opt_idx,
optimizer,
self.hiddens
self.hiddens,
).loss
# apply TPU optimizer

@@ -869,7 +869,7 @@ class TrainerTrainLoopMixin(ABC):
elif isinstance(optimizer, torch.optim.LBFGS):
# native amp + lbfgs is a no go right now
if self.use_amp and NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.NATIVE:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')

@@ -878,12 +878,12 @@ class TrainerTrainLoopMixin(ABC):
# when using 16-bit
else:
native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE
native_amp = self.amp_type == AMPType.NATIVE
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
using_native_amp=native_amp)
# in native 16-bit we need to update scaler after optimizer step
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
self.scaler.update()
# model hook

@@ -900,7 +900,7 @@ class TrainerTrainLoopMixin(ABC):
# FORWARD (TRAINING STEP + TRAIN STEP END)
# ---------------------------
with self.profiler.profile('model_forward'):
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
training_step_output = self.training_forward(split_batch, batch_idx,
opt_idx, hiddens)

@@ -954,10 +954,10 @@ class TrainerTrainLoopMixin(ABC):
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_type=self.amp_type)
# enter amp context
if not NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.APEX:
context = closure_loss
closure_loss = closure_loss.__enter__()

@@ -965,7 +965,7 @@ class TrainerTrainLoopMixin(ABC):
model_ref.backward(self, closure_loss, optimizer, opt_idx)
# exit amp context
if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
if self.precision == 16 and self.amp_type == AMPType.APEX and not self.on_tpu:
a, b, c = None, None, None
error = context.__exit__(a, b, c)
if error:
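The `__enter__`/`__exit__` juggling above exists because Apex's `amp.scale_loss` is a context manager while the native `GradScaler.scale` simply returns a tensor, and Lightning's closure plumbing has to treat both as a value. Outside that plumbing the Apex side is normally written as a plain `with` block; a hedged sketch, assuming `model`, `optimizer` and `loss` already exist and `amp.initialize(model, optimizer, opt_level='O2')` has been called:

    from apex import amp

    # scale the loss inside the context manager, backprop, then step as usual
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()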
@@ -24,16 +24,14 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
EPSILON = 1e-6
EPSILON_FP16 = 1e-5

@@ -48,6 +46,7 @@ class TrainerTrainingTricksMixin(ABC):
default_root_dir: str
progress_bar_callback: ...
on_gpu: bool
amp_type: AMPType
@abstractmethod
def get_model(self) -> LightningModule:

@@ -72,7 +71,7 @@ class TrainerTrainingTricksMixin(ABC):
if self.gradient_clip_val <= 0:
return
model = self.get_model()
if self.use_amp and not NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.APEX:
parameters = amp.master_params(optimizer)
else:
parameters = model.parameters()
@@ -1,4 +1,5 @@
"""General utilities"""
from enum import Enum
import numpy
import torch

@@ -19,3 +20,8 @@ NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
class AMPType(Enum):
APEX = 'apex'
NATIVE = 'native'
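The `AMPType` enum added here is what every other file in this commit branches on. Its values mirror the strings accepted by the new `amp_type` Trainer argument ('apex' / 'native'), so the user-facing string maps onto the enum directly; a small illustration (the `resolve` helper is only for this example and is not part of the commit):

    from enum import Enum


    class AMPType(Enum):
        APEX = 'apex'
        NATIVE = 'native'


    def resolve(amp_type: str) -> AMPType:
        # hypothetical helper: look the enum member up by its string value
        return AMPType(amp_type.lower())


    assert resolve('Native') is AMPType.NATIVE
    assert resolve('apex') is AMPType.APEX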