allow using apex with any PT version (#2865)

* wip

* setup

* type

* name

* wip

* docs

* imports

* fix if

* fix if

* use_amp

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* fix tests

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* fix tests

* todos

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Jirka Borovec 2020-08-08 11:07:32 +02:00 committed by GitHub
parent fed0ac838b
commit a6e7aa7796
20 changed files with 140 additions and 139 deletions

View File

@ -39,7 +39,6 @@ RUN apt-get update && \
&& \
# Install AMP
# TODO: skip this install for PT >= 1.6
bash install_AMP.sh && \
# Install all requirements
pip install -r requirements.txt && \
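
Note: the TODO above refers to PyTorch 1.6 shipping its own AMP implementation, which makes the Apex build step optional. A minimal sketch of the availability probe used elsewhere in this diff (the repository spells the constant `NATIVE_AMP_AVALAIBLE`); the snippet is illustrative and not part of the commit:

```python
import torch

# Same probe as NATIVE_AMP_AVALAIBLE further down in this diff:
# torch.cuda.amp with autocast only exists from PyTorch 1.6 onward.
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

if native_amp_available:
    print("native AMP available - the Apex build step can be skipped")
else:
    print("no native AMP - NVIDIA Apex is required for 16-bit training")
```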

View File

@ -22,7 +22,7 @@ class CPUBackend(object):
def setup(self, model):
# run through amp wrapper
if self.trainer.use_amp:
if self.trainer.amp_type:
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')
# call setup after the ddp process has connected

View File

@ -17,7 +17,7 @@ import os
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -32,9 +32,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDP2Backend(object):
@ -135,10 +133,8 @@ class DDP2Backend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
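
For context, `configure_apex` is a LightningModule hook that the backends above call on the Apex path. A sketch of what the hook typically does (the default defers to Apex's `amp.initialize`); `MyLitModel` is a placeholder name, not part of this commit:

```python
from pytorch_lightning.core import LightningModule


class MyLitModel(LightningModule):  # hypothetical module, shown for illustration
    def configure_apex(self, amp, model, optimizers, amp_level):
        # `amp` is the apex.amp module handed in by the backend above;
        # let Apex patch the model and optimizers for mixed precision
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers
```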

View File

@ -23,7 +23,7 @@ import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
try:
@ -37,9 +37,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDPBackend(object):
@ -202,10 +200,8 @@ class DDPBackend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

View File

@ -16,14 +16,13 @@ import torch
import torch.multiprocessing as mp
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DDPSpawnBackend(object):
@ -133,11 +132,9 @@ class DDPSpawnBackend(object):
# set model properties before going into wrapper
self.trainer.copy_trainer_model_properties(model)
# AMP
# AMP -
# run through amp wrapper before going to distributed DP
# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if self.trainer.use_amp and not native_amp_available:
if self.trainer.amp_type == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

View File

@ -16,14 +16,13 @@ import torch
from torch import optim
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class DataParallelBackend(object):
@ -50,7 +49,7 @@ class DataParallelBackend(object):
self.model_autocast_original_forward = model.forward
# init half precision
if self.trainer.use_amp:
if self.trainer.amp_type:
model = self.__init_half_precision(model)
# init torch data parallel
@ -70,9 +69,7 @@ class DataParallelBackend(object):
return model
def __init_half_precision(self, model):
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if native_amp_available:
if self.trainer.amp_type == AMPType.NATIVE:
self.__init_native_amp(model)
else:
model = self.__init_nvidia_apex(model)

View File

@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class GPUBackend(object):
amp_type: AMPType
def __init__(self, trainer):
self.trainer = trainer
@ -43,9 +41,7 @@ class GPUBackend(object):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies
# TODO: remove with dropping NVIDIA AMP support
native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available:
if self.trainer.amp_type == AMPType.APEX:
model = self._setup_nvidia_apex(model)
return model

View File

@ -5,14 +5,12 @@ from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import move_data_to_device, AMPType
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
class ModelHooks(Module):
@ -267,8 +265,8 @@ class ModelHooks(Module):
"""
loss.backward()
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx):
if NATIVE_AMP_AVALAIBLE:
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType):
if amp_type == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
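
The two branches above mirror the two backends: native AMP scales through the trainer's `GradScaler`, while Apex returns the `amp.scale_loss` context manager, which the training loop enters and exits manually (see the training-loop diff later in this commit). For reference, a sketch of the canonical standalone Apex pattern, assuming Apex is installed and a CUDA device is available:

```python
import torch
from apex import amp  # optional dependency in this diff; amp is None when missing

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(8, 4, device="cuda")).mean()
# scale_loss yields the scaled loss as a context manager;
# Lightning splits the __enter__/__exit__ calls across its hooks instead.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```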

View File

@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
UNKNOWN_SIZE = "?"
@ -207,8 +207,7 @@ class ModelSummary(object):
input_ = model.example_input_array
input_ = model.transfer_batch_to_device(input_, model.device)
if trainer is not None and trainer.use_amp and not trainer.use_tpu:
if NATIVE_AMP_AVALAIBLE:
if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu:
model.forward = torch.cuda.amp.autocast()(model.forward)
mode = model.training
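
The summary dry run wraps `model.forward` in `torch.cuda.amp.autocast()` when the native backend is active. This works because `autocast()` can be used as a decorator as well as a context manager; a small sketch, assuming a CUDA build of PyTorch 1.6+:

```python
import torch

net = torch.nn.Linear(16, 4).cuda()

# autocast() instances are callable and can wrap a function,
# which is why patching `model.forward` (as above) is sufficient.
net.forward = torch.cuda.amp.autocast()(net.forward)

out = net(torch.randn(2, 16, device="cuda"))
print(out.dtype)  # float16 under autocast on CUDA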

View File

@ -864,6 +864,19 @@ Enable synchronization between batchnorm layers across all GPUs.
trainer = Trainer(sync_batchnorm=True)
amp_type
^^^^^^^^
Define the preferred mixed-precision backend: either NVIDIA Apex ("apex") or PyTorch's built-in ("native") AMP, which is available from PyTorch 1.6.
.. testcode::
# using NVIDIA Apex
trainer = Trainer(amp_type='apex')
# using PyTorch built-in AMP
trainer = Trainer(amp_type='native')
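
Note that `amp_type` only takes effect when 16-bit precision is requested, and `amp_level` (the O1/O2 optimization level) applies to the Apex backend only. A combined example:

```python
from pytorch_lightning import Trainer

# 16-bit training via NVIDIA Apex at optimization level O2
trainer = Trainer(precision=16, amp_type='apex', amp_level='O2')
```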
val_percent_check
^^^^^^^^^^^^^^^^^

View File

@ -1,7 +1,7 @@
from abc import ABC
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
class TrainerAMPMixin(ABC):
@ -10,26 +10,39 @@ class TrainerAMPMixin(ABC):
# the proper values/initialisation should be done in child class
precision: int
def init_amp(self):
if NATIVE_AMP_AVALAIBLE:
log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
if self.use_amp and NATIVE_AMP_AVALAIBLE:
log.info('Using native 16bit precision.')
def _setup_amp_type(self, amp_type: str):
self.amp_type = None
if self.precision != 16:
# no AMP requested, so we can leave now
return
if self.use_amp and not APEX_AVAILABLE: # pragma: no-cover
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
if not NATIVE_AMP_AVALAIBLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
amp_type = 'apex'
else:
log.info('Using native 16bit precision.')
self.amp_type = AMPType.NATIVE
if amp_type == 'apex':
if not APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:
log.info('Using APEX 16bit precision.')
self.amp_type = AMPType.APEX
if not self.amp_type:
raise ModuleNotFoundError(
"You set `use_amp=True` but do not have apex installed."
" Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
" and rerun with `use_amp=True`."
" This run will NOT use 16 bit precision."
f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
f' Consider installing torch >= 1.6 or NVIDIA Apex.'
)
if self.use_amp:
log.info('Using APEX 16bit precision.')
def init_amp(self, amp_type: str):
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
self._setup_amp_type(amp_type)
@property
def use_amp(self) -> bool:
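
To summarize the selection rule implemented by `_setup_amp_type` above: prefer the requested backend, fall back from native to Apex with a warning when PyTorch is too old, and raise when neither backend is usable. A simplified standalone sketch (`resolve_amp_type` is a hypothetical helper, not part of this commit):

```python
from enum import Enum
import warnings

import torch


class AMPType(Enum):
    APEX = 'apex'
    NATIVE = 'native'


def resolve_amp_type(requested: str, precision: int, apex_available: bool):
    """Simplified mirror of TrainerAMPMixin._setup_amp_type."""
    if precision != 16:
        return None  # no mixed precision requested
    requested = requested.lower()
    assert requested in ('native', 'apex'), f'Unsupported amp type {requested}'

    native_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
    if requested == 'native' and not native_available:
        warnings.warn('native AMP needs torch>=1.6, falling back to NVIDIA Apex')
        requested = 'apex'
    if requested == 'native':
        return AMPType.NATIVE
    if apex_available:
        return AMPType.APEX
    raise ModuleNotFoundError('neither native AMP (torch>=1.6) nor NVIDIA Apex is available')
```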

View File

@ -22,9 +22,7 @@ except ImportError:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla

View File

@ -131,24 +131,27 @@ import os
import re
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Optional, Tuple
from typing import Union, List, Optional, Callable, Tuple
import subprocess
import sys
from time import sleep
import numpy as np
import torch
from os.path import abspath
from pkg_resources import parse_version
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import horovod.torch as hvd
@ -207,11 +210,6 @@ class TrainerDDPMixin(ABC):
def num_gpus(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""
@property
@abstractmethod
def use_amp(self) -> bool:
"""Warning: this is just empty shell for code implemented in other class."""
@abstractmethod
def copy_trainer_model_properties(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

View File

@ -18,30 +18,29 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
import random
import time
from abc import ABC, abstractmethod
from contextlib import ExitStack
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
import os
from abc import ABC, abstractmethod
import time
import random
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import move_data_to_device, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.distributed import rank_zero_only
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla.core.xla_model as xm
@ -80,11 +79,7 @@ class TrainerDPMixin(ABC):
on_colab_kaggle: str
save_spawn_weights: Callable
logger: ...
@property
@abstractmethod
def use_amp(self) -> bool:
"""Warning: this is just empty shell for code implemented in other class."""
amp_type: AMPType
@abstractmethod
def call_setup_hook(self, *args):
@ -128,7 +123,7 @@ class TrainerDPMixin(ABC):
m.use_dp = self.use_dp
m.use_ddp2 = self.use_ddp2
m.use_ddp = self.use_ddp
m.use_amp = self.use_amp
m.use_amp = self.amp_type is not None
m.testing = self.testing
m.use_single_gpu = self.use_single_gpu
m.use_tpu = self.use_tpu
@ -210,7 +205,7 @@ class TrainerDPMixin(ABC):
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
if self.use_amp:
if self.amp_type:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

View File

@ -131,10 +131,11 @@ import torch
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
from pytorch_lightning.core.step_result import Result, EvalResult
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.core.xla_model as xm
@ -179,6 +180,7 @@ class TrainerEvaluationLoopMixin(ABC):
tpu_id: int
verbose_test: bool
running_sanity_check: bool
amp_type: AMPType
# Callback system
on_validation_batch_start: Callable
@ -316,7 +318,7 @@ class TrainerEvaluationLoopMixin(ABC):
# -----------------
# RUN EVALUATION STEP
# -----------------
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
else:

View File

@ -31,7 +31,7 @@ from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
@ -49,7 +49,7 @@ from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -61,9 +61,7 @@ warnings.filterwarnings(
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla
@ -199,6 +197,7 @@ class Trainer(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
amp_type: str = 'native',
amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0
val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0
test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0
@ -308,6 +307,7 @@ class Trainer(
Defaults to `default_root_dir`.
amp_level: The optimization level to use (O1, O2, etc...).
.. warning:: .. deprecated:: v0.7.4
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
@ -587,7 +587,7 @@ class Trainer(
self.scaler = None
self.amp_level = amp_level
self.init_amp()
self.init_amp(amp_type)
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
@ -1128,7 +1128,7 @@ class Trainer(
self.copy_trainer_model_properties(ref_model)
# init amp. Must be done here instead of __init__ to allow ddp to work
if NATIVE_AMP_AVALAIBLE and self.precision == 16 and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
self.scaler = torch.cuda.amp.GradScaler()
# log hyper-parameters

View File

@ -102,7 +102,7 @@ from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import load as pl_load
try:
@ -117,9 +117,7 @@ else:
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import horovod.torch as hvd
@ -157,8 +155,9 @@ class TrainerIOMixin(ABC):
on_tpu: bool
num_training_batches: int
accumulate_grad_batches: int
use_amp: bool
scaler: ...
use_tpu: bool
amp_type: AMPType
def get_model(self):
is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))
@ -323,9 +322,9 @@ class TrainerIOMixin(ABC):
model.cuda(self.root_gpu)
# restore amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])
# load training state (affects trainer only)
@ -376,9 +375,9 @@ class TrainerIOMixin(ABC):
checkpoint['lr_schedulers'] = lr_schedulers
# save native amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
elif self.use_amp and not NATIVE_AMP_AVALAIBLE:
elif self.amp_type == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()
# add the module_arguments and state_dict from the model
@ -533,9 +532,9 @@ class TrainerIOMixin(ABC):
model.load_state_dict(checkpoint['state_dict'])
# restore amp scaling
if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint:
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint:
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])
if self.root_gpu is not None:
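
The checkpoint hunks above store each backend's scaling state under its own key. A hedged sketch of the same round trip outside Lightning (variable names are illustrative):

```python
import torch

scaler = torch.cuda.amp.GradScaler()  # trainer.scaler when amp_type == AMPType.NATIVE

# saving: stash the active backend's state under a dedicated key
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}
# for Apex the equivalent would be: checkpoint['amp_scaling_state'] = amp.state_dict()

# restoring: only load the key that matches the active backend
if 'native_amp_scaling_state' in checkpoint:
    scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
```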

View File

@ -175,7 +175,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.parsing import AttributeDict
@ -183,9 +183,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
try:
import torch_xla.distributed.parallel_loader as xla_pl
@ -255,6 +253,8 @@ class TrainerTrainLoopMixin(ABC):
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...
amp_type: AMPType
on_tpu: bool
# Callback system
callbacks: List[Callback]
@ -739,7 +739,7 @@ class TrainerTrainLoopMixin(ABC):
batch_idx,
opt_idx,
optimizer,
self.hiddens
self.hiddens,
)
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
@ -835,7 +835,7 @@ class TrainerTrainLoopMixin(ABC):
# ------------------
# CLIP GRADS
# ------------------
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
self.scaler.unscale_(optimizer)
self.clip_gradients(optimizer)
@ -857,7 +857,7 @@ class TrainerTrainLoopMixin(ABC):
batch_idx,
opt_idx,
optimizer,
self.hiddens
self.hiddens,
).loss
# apply TPU optimizer
@ -869,7 +869,7 @@ class TrainerTrainLoopMixin(ABC):
elif isinstance(optimizer, torch.optim.LBFGS):
# native amp + lbfgs is a no go right now
if self.use_amp and NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.NATIVE:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@ -878,12 +878,12 @@ class TrainerTrainLoopMixin(ABC):
# when using 16-bit
else:
native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE
native_amp = self.amp_type == AMPType.NATIVE
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
using_native_amp=native_amp)
# in native 16-bit we need to update scaler after optimizer step
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
self.scaler.update()
# model hook
@ -900,7 +900,7 @@ class TrainerTrainLoopMixin(ABC):
# FORWARD (TRAINING STEP + TRAIN STEP END)
# ---------------------------
with self.profiler.profile('model_forward'):
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
training_step_output = self.training_forward(split_batch, batch_idx,
opt_idx, hiddens)
@ -954,10 +954,10 @@ class TrainerTrainLoopMixin(ABC):
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx)
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_type=self.amp_type)
# enter amp context
if not NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.APEX:
context = closure_loss
closure_loss = closure_loss.__enter__()
@ -965,7 +965,7 @@ class TrainerTrainLoopMixin(ABC):
model_ref.backward(self, closure_loss, optimizer, opt_idx)
# exit amp context
if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu:
if self.precision == 16 and self.amp_type == AMPType.APEX and not self.on_tpu:
a, b, c = None, None, None
error = context.__exit__(a, b, c)
if error:
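
The hunks above interleave autocast, unscale-before-clip, and the scaler update across several Lightning hooks. For orientation, a plain-PyTorch sketch of the same native-AMP step outside Lightning, assuming a CUDA device:

```python
import torch

device = "cuda"
model = torch.nn.Linear(32, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

batch = torch.randn(16, 32, device=device)

# same ordering the training loop enforces above:
# forward under autocast -> scaled backward -> unscale -> clip -> step -> update
with torch.cuda.amp.autocast():
    loss = model(batch).pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                      # so gradient clipping sees real grads
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```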

View File

@ -24,16 +24,14 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
try:
from apex import amp
except ImportError:
APEX_AVAILABLE = False
else:
APEX_AVAILABLE = True
amp = None
EPSILON = 1e-6
EPSILON_FP16 = 1e-5
@ -48,6 +46,7 @@ class TrainerTrainingTricksMixin(ABC):
default_root_dir: str
progress_bar_callback: ...
on_gpu: bool
amp_type: AMPType
@abstractmethod
def get_model(self) -> LightningModule:
@ -72,7 +71,7 @@ class TrainerTrainingTricksMixin(ABC):
if self.gradient_clip_val <= 0:
return
model = self.get_model()
if self.use_amp and not NATIVE_AMP_AVALAIBLE:
if self.amp_type == AMPType.APEX:
parameters = amp.master_params(optimizer)
else:
parameters = model.parameters()
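
The branch above exists because under Apex O2 the optimizer holds FP32 master copies of the weights, so gradient clipping has to look at those instead of `model.parameters()`. A small sketch of that dispatch (`params_to_clip` is a hypothetical helper; requires NVIDIA Apex for the first branch):

```python
from apex import amp  # Apex-only path
from pytorch_lightning.utilities import AMPType


def params_to_clip(trainer, model, optimizer):
    # mirror of the branch above: clip the FP32 master params when Apex is active
    if trainer.amp_type == AMPType.APEX:
        return amp.master_params(optimizer)
    return model.parameters()
```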

View File

@ -1,4 +1,5 @@
"""General utilities"""
from enum import Enum
import numpy
import torch
@ -19,3 +20,8 @@ NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
class AMPType(Enum):
APEX = 'apex'
NATIVE = 'native'
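
Because plain `Enum` members are truthy, checks like `if self.trainer.amp_type:` in the backends above read as "is any AMP backend configured?", while equality against a member selects a specific backend. A small sketch:

```python
from enum import Enum


class AMPType(Enum):
    APEX = 'apex'
    NATIVE = 'native'


amp_type = AMPType.NATIVE

# enum members are truthy, None means "no mixed precision" (precision != 16)
assert bool(amp_type)
assert bool(None) is False

# backend dispatch compares against a specific member
if amp_type == AMPType.NATIVE:
    print("use torch.cuda.amp (GradScaler + autocast)")
elif amp_type == AMPType.APEX:
    print("use NVIDIA Apex (amp.initialize + scale_loss)")
```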