From a6e7aa779672812e0583b6a1df473cc27f782e5b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 8 Aug 2020 11:07:32 +0200 Subject: [PATCH] allow using apex with any PT version (#2865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * wip * setup * type * name * wip * docs * imports * fix if * fix if * use_amp * Apply suggestions from code review Co-authored-by: Adrian Wälchli * Apply suggestions from code review Co-authored-by: Adrian Wälchli * fix tests * Apply suggestions from code review Co-authored-by: Adrian Wälchli * fix tests * todos Co-authored-by: Adrian Wälchli --- dockers/cuda-extras/Dockerfile | 1 - pytorch_lightning/accelerators/cpu_backend.py | 2 +- .../accelerators/ddp2_backend.py | 12 ++--- pytorch_lightning/accelerators/ddp_backend.py | 12 ++--- .../accelerators/ddp_spawn_backend.py | 11 ++--- pytorch_lightning/accelerators/dp_backend.py | 11 ++--- pytorch_lightning/accelerators/gpu_backend.py | 12 ++--- pytorch_lightning/core/hooks.py | 10 ++-- pytorch_lightning/core/memory.py | 5 +- pytorch_lightning/trainer/__init__.py | 13 +++++ .../trainer/auto_mix_precision.py | 47 ++++++++++++------- pytorch_lightning/trainer/data_loading.py | 4 +- .../trainer/distrib_data_parallel.py | 26 +++++----- pytorch_lightning/trainer/distrib_parts.py | 29 +++++------- pytorch_lightning/trainer/evaluation_loop.py | 6 ++- pytorch_lightning/trainer/trainer.py | 14 +++--- pytorch_lightning/trainer/training_io.py | 21 ++++----- pytorch_lightning/trainer/training_loop.py | 28 +++++------ pytorch_lightning/trainer/training_tricks.py | 9 ++-- pytorch_lightning/utilities/__init__.py | 6 +++ 20 files changed, 140 insertions(+), 139 deletions(-) diff --git a/dockers/cuda-extras/Dockerfile b/dockers/cuda-extras/Dockerfile index a1aaff0e7d..f3c435ff53 100644 --- a/dockers/cuda-extras/Dockerfile +++ b/dockers/cuda-extras/Dockerfile @@ -39,7 +39,6 @@ RUN apt-get update && \ && \ # Install AMP - # TODO: skip this instrall for PT >= 1.6 bash install_AMP.sh && \ # Install all requirements pip install -r requirements.txt && \ diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py index 7760442a20..cfee51e4dd 100644 --- a/pytorch_lightning/accelerators/cpu_backend.py +++ b/pytorch_lightning/accelerators/cpu_backend.py @@ -22,7 +22,7 @@ class CPUBackend(object): def setup(self, model): # run through amp wrapper - if self.trainer.use_amp: + if self.trainer.amp_type: raise MisconfigurationException('amp + cpu is not supported. 
Please use a GPU option') # call setup after the ddp process has connected diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 8a4b0f6584..85bda4cd8d 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -17,7 +17,7 @@ import os import torch from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -32,9 +32,7 @@ else: try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDP2Backend(object): @@ -135,10 +133,8 @@ class DDP2Backend(object): # set model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE: + # AMP - run through amp wrapper before going to distributed DP + if self.trainer.amp_type == AMPType.APEX: model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index 163e3faec8..e499feda65 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -23,7 +23,7 @@ import numpy as np import torch from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only try: @@ -37,9 +37,7 @@ else: try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDPBackend(object): @@ -202,10 +200,8 @@ class DDPBackend(object): # set model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE: + # AMP - run through amp wrapper before going to distributed DP + if self.trainer.amp_type == AMPType.APEX: model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py index fe5b33af47..9ed68f6608 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -16,14 +16,13 @@ import torch import torch.multiprocessing as mp from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DDPSpawnBackend(object): @@ -133,11 +132,9 @@ class DDPSpawnBackend(object): # set 
model properties before going into wrapper self.trainer.copy_trainer_model_properties(model) - # AMP + # AMP - # run through amp wrapper before going to distributed DP - # TODO: remove with dropping NVIDIA AMP support - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - if self.trainer.use_amp and not native_amp_available: + if self.trainer.amp_type == AMPType.APEX: model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level) self.trainer.optimizers = optimizers self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers) diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py index a6fb4ca3e3..31791ee5ec 100644 --- a/pytorch_lightning/accelerators/dp_backend.py +++ b/pytorch_lightning/accelerators/dp_backend.py @@ -16,14 +16,13 @@ import torch from torch import optim from pytorch_lightning.overrides.data_parallel import LightningDataParallel +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class DataParallelBackend(object): @@ -50,7 +49,7 @@ class DataParallelBackend(object): self.model_autocast_original_forward = model.forward # init half precision - if self.trainer.use_amp: + if self.trainer.amp_type: model = self.__init_half_precision(model) # init torch data parallel @@ -70,9 +69,7 @@ class DataParallelBackend(object): return model def __init_half_precision(self, model): - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - - if native_amp_available: + if self.trainer.amp_type == AMPType.NATIVE: self.__init_native_amp(model) else: model = self.__init_nvidia_apex(model) diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index 846457f126..30920998b2 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch - from pytorch_lightning.core import LightningModule +from pytorch_lightning.utilities import AMPType try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class GPUBackend(object): + amp_type: AMPType def __init__(self, trainer): self.trainer = trainer @@ -43,9 +41,7 @@ class GPUBackend(object): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies - # TODO: remove with dropping NVIDIA AMP support - native_amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") - if APEX_AVAILABLE and self.trainer.use_amp and not native_amp_available: + if self.trainer.amp_type == AMPType.APEX: model = self._setup_nvidia_apex(model) return model diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index fd356fd134..1695e090f0 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -5,14 +5,12 @@ from torch import Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer -from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import move_data_to_device, AMPType try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None class ModelHooks(Module): @@ -267,8 +265,8 @@ class ModelHooks(Module): """ loss.backward() - def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx): - if NATIVE_AMP_AVALAIBLE: + def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType): + if amp_type == AMPType.NATIVE: scaled_loss = self.trainer.scaler.scale(unscaled_loss) else: scaled_loss = amp.scale_loss(unscaled_loss, optimizer) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index b894983a8f..f55b8e026c 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn from torch.utils.hooks import RemovableHandle -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import AMPType PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"] UNKNOWN_SIZE = "?" @@ -207,8 +207,7 @@ class ModelSummary(object): input_ = model.example_input_array input_ = model.transfer_batch_to_device(input_, model.device) - if trainer is not None and trainer.use_amp and not trainer.use_tpu: - if NATIVE_AMP_AVALAIBLE: + if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu: model.forward = torch.cuda.amp.autocast()(model.forward) mode = model.training diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index a3486c27ed..6b8a58ace5 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -864,6 +864,19 @@ Enable synchronization between batchnorm layers across all GPUs. trainer = Trainer(sync_batchnorm=True) +amp_type +^^^^^^^^ + +Define a preferable mixed precision, either NVIDIA Apex ("apex") or PyTorch built-in ("native") AMP which is supported from v1.6. + +.. 
testcode:: + + # using NVIDIA Apex + trainer = Trainer(amp_type='apex') + + # using PyTorch built-in AMP + trainer = Trainer(amp_type='native') + val_percent_check ^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/auto_mix_precision.py index 71ff71d0ec..06e823117c 100644 --- a/pytorch_lightning/trainer/auto_mix_precision.py +++ b/pytorch_lightning/trainer/auto_mix_precision.py @@ -1,7 +1,7 @@ from abc import ABC from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType class TrainerAMPMixin(ABC): @@ -10,26 +10,39 @@ class TrainerAMPMixin(ABC): # the proper values/initialisation should be done in child class precision: int - def init_amp(self): - if NATIVE_AMP_AVALAIBLE: - log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)") - - assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' - - if self.use_amp and NATIVE_AMP_AVALAIBLE: - log.info('Using native 16bit precision.') + def _setup_amp_type(self, amp_type: str): + self.amp_type = None + if self.precision != 16: + # no AMP requested, so we can leave now return - - if self.use_amp and not APEX_AVAILABLE: # pragma: no-cover + amp_type = amp_type.lower() + assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' + if amp_type == 'native': + if not NATIVE_AMP_AVALAIBLE: + rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.' + ' Consider upgrading with `pip install torch>=1.6`.' + ' We will attempt to use NVIDIA Apex for this session.') + amp_type = 'apex' + else: + log.info('Using native 16bit precision.') + self.amp_type = AMPType.NATIVE + if amp_type == 'apex': + if not APEX_AVAILABLE: + rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' + ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') + else: + log.info('Using APEX 16bit precision.') + self.amp_type = AMPType.APEX + if not self.amp_type: raise ModuleNotFoundError( - "You set `use_amp=True` but do not have apex installed." - " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" - " and rerun with `use_amp=True`." - " This run will NOT use 16 bit precision." + f'You have asked for AMP support {amp_type}, but there is no support on your side yet.' + f' Consider installing torch >= 1.6 or NVIDIA Apex.' 
) - if self.use_amp: - log.info('Using APEX 16bit precision.') + def init_amp(self, amp_type: str): + assert self.precision in (16, 32), 'only 32 or 16 bit precision supported' + + self._setup_amp_type(amp_type) @property def use_amp(self) -> bool: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 38a1118118..575e28354b 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -22,9 +22,7 @@ except ImportError: try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index f956f63307..c550fb648f 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -131,24 +131,27 @@ import os import re from abc import ABC, abstractmethod from distutils.version import LooseVersion -from typing import Union, List, Optional, Tuple - +from typing import Union, List, Optional, Callable, Tuple +import subprocess +import sys +from time import sleep import numpy as np -import torch +from os.path import abspath +from pkg_resources import parse_version +import torch from pytorch_lightning import _logger as log +from pytorch_lightning.loggers import LightningLoggerBase +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info -from pytorch_lightning.utilities.exceptions import MisconfigurationException + try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import horovod.torch as hvd @@ -207,11 +210,6 @@ class TrainerDDPMixin(ABC): def num_gpus(self) -> int: """Warning: this is just empty shell for code implemented in other class.""" - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod def copy_trainer_model_properties(self, *args): """Warning: this is just empty shell for code implemented in other class.""" diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index f647f9ca3f..f76c6f1b00 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -18,30 +18,29 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. 
""" -import random -import time -from abc import ABC, abstractmethod from contextlib import ExitStack -from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence - +import os +from abc import ABC, abstractmethod +import time +import random import torch from torch.optim.lr_scheduler import _LRScheduler +from typing import Union, Callable, Any, List, Optional, Tuple, MutableSequence from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning import _logger as log from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities import move_data_to_device -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities import move_data_to_device, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.distributed import rank_zero_only try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla.core.xla_model as xm @@ -80,11 +79,7 @@ class TrainerDPMixin(ABC): on_colab_kaggle: str save_spawn_weights: Callable logger: ... - - @property - @abstractmethod - def use_amp(self) -> bool: - """Warning: this is just empty shell for code implemented in other class.""" + amp_type: AMPType @abstractmethod def call_setup_hook(self, *args): @@ -128,7 +123,7 @@ class TrainerDPMixin(ABC): m.use_dp = self.use_dp m.use_ddp2 = self.use_ddp2 m.use_ddp = self.use_ddp - m.use_amp = self.use_amp + m.use_amp = self.amp_type is not None m.testing = self.testing m.use_single_gpu = self.use_single_gpu m.use_tpu = self.use_tpu @@ -210,7 +205,7 @@ class TrainerDPMixin(ABC): if isinstance(scheduler, _LRScheduler): scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs] - if self.use_amp: + if self.amp_type: model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) self.optimizers = optimizers self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index f2ad3cb4ef..5ce7b7718c 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -131,10 +131,11 @@ import torch from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType from pytorch_lightning.core.step_result import Result, EvalResult -from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE, flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException + try: import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.core.xla_model as xm @@ -179,6 +180,7 @@ class TrainerEvaluationLoopMixin(ABC): tpu_id: int verbose_test: bool running_sanity_check: bool + amp_type: AMPType # Callback system on_validation_batch_start: Callable @@ -316,7 +318,7 @@ class TrainerEvaluationLoopMixin(ABC): # ----------------- # RUN EVALUATION STEP # ----------------- - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: with torch.cuda.amp.autocast(): output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) else: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 
d869131944..b8d246b5c8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -31,7 +31,7 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler -from pytorch_lightning.trainer.auto_mix_precision import NATIVE_AMP_AVALAIBLE, TrainerAMPMixin +from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator @@ -49,7 +49,7 @@ from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -61,9 +61,7 @@ warnings.filterwarnings( try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla @@ -199,6 +197,7 @@ class Trainer( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: bool = True, + amp_type: str = 'native', amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0 val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0 @@ -308,6 +307,7 @@ class Trainer( Defaults to `default_root_dir`. amp_level: The optimization level to use (O1, O2, etc...). + .. warning:: .. deprecated:: v0.7.4 num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine. Set it to `-1` to run all batches in all validation dataloaders. Default: 2 @@ -587,7 +587,7 @@ class Trainer( self.scaler = None self.amp_level = amp_level - self.init_amp() + self.init_amp(amp_type) self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') @@ -1128,7 +1128,7 @@ class Trainer( self.copy_trainer_model_properties(ref_model) # init amp. 
Must be done here instead of __init__ to allow ddp to work - if NATIVE_AMP_AVALAIBLE and self.precision == 16 and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and self.precision == 16 and not self.use_tpu: self.scaler = torch.cuda.amp.GradScaler() # log hyper-parameters diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 30e3d7c0e3..b112f0dbad 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -102,7 +102,7 @@ from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.cloud_io import load as pl_load try: @@ -117,9 +117,7 @@ else: try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import horovod.torch as hvd @@ -157,8 +155,9 @@ class TrainerIOMixin(ABC): on_tpu: bool num_training_batches: int accumulate_grad_batches: int - use_amp: bool scaler: ... + use_tpu: bool + amp_type: AMPType def get_model(self): is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel)) @@ -323,9 +322,9 @@ class TrainerIOMixin(ABC): model.cuda(self.root_gpu) # restore amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint: + if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint: + elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint: amp.load_state_dict(checkpoint['amp_scaling_state']) # load training state (affects trainer only) @@ -376,9 +375,9 @@ class TrainerIOMixin(ABC): checkpoint['lr_schedulers'] = lr_schedulers # save native amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() - elif self.use_amp and not NATIVE_AMP_AVALAIBLE: + elif self.amp_type == AMPType.APEX: checkpoint['amp_scaling_state'] = amp.state_dict() # add the module_arguments and state_dict from the model @@ -533,9 +532,9 @@ class TrainerIOMixin(ABC): model.load_state_dict(checkpoint['state_dict']) # restore amp scaling - if self.use_amp and NATIVE_AMP_AVALAIBLE and 'native_amp_scaling_state' in checkpoint: + if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: self.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) - elif self.use_amp and not NATIVE_AMP_AVALAIBLE and 'amp_scaling_state' in checkpoint: + elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint: amp.load_state_dict(checkpoint['amp_scaling_state']) if self.root_gpu is not None: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 757826393f..72178b8e8e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -175,7 +175,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator 
-from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import rank_zero_warn, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.parsing import AttributeDict @@ -183,9 +183,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None try: import torch_xla.distributed.parallel_loader as xla_pl @@ -255,6 +253,8 @@ class TrainerTrainLoopMixin(ABC): terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... + amp_type: AMPType + on_tpu: bool # Callback system callbacks: List[Callback] @@ -739,7 +739,7 @@ class TrainerTrainLoopMixin(ABC): batch_idx, opt_idx, optimizer, - self.hiddens + self.hiddens, ) using_results_obj = isinstance(opt_closure_result.training_step_output, Result) @@ -835,7 +835,7 @@ class TrainerTrainLoopMixin(ABC): # ------------------ # CLIP GRADS # ------------------ - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: self.scaler.unscale_(optimizer) self.clip_gradients(optimizer) @@ -857,7 +857,7 @@ class TrainerTrainLoopMixin(ABC): batch_idx, opt_idx, optimizer, - self.hiddens + self.hiddens, ).loss # apply TPU optimizer @@ -869,7 +869,7 @@ class TrainerTrainLoopMixin(ABC): elif isinstance(optimizer, torch.optim.LBFGS): # native amp + lbfgs is a no go right now - if self.use_amp and NATIVE_AMP_AVALAIBLE: + if self.amp_type == AMPType.NATIVE: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' ' To request, please file a Github issue in PyTorch and tag @mcarilli') @@ -878,12 +878,12 @@ class TrainerTrainLoopMixin(ABC): # when using 16-bit else: - native_amp = self.use_amp and NATIVE_AMP_AVALAIBLE + native_amp = self.amp_type == AMPType.NATIVE model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure, using_native_amp=native_amp) # in native 16-bit we need to update scaler after optimizer step - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: self.scaler.update() # model hook @@ -900,7 +900,7 @@ class TrainerTrainLoopMixin(ABC): # FORWARD (TRAINING STEP + TRAIN STEP END) # --------------------------- with self.profiler.profile('model_forward'): - if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: + if self.amp_type == AMPType.NATIVE and not self.use_tpu: with torch.cuda.amp.autocast(): training_step_output = self.training_forward(split_batch, batch_idx, opt_idx, hiddens) @@ -954,10 +954,10 @@ class TrainerTrainLoopMixin(ABC): with self.profiler.profile('model_backward'): # scale loss for 16 bit if self.precision == 16 and not self.on_tpu: - closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx) + closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_type=self.amp_type) # enter amp context - if not NATIVE_AMP_AVALAIBLE: + if self.amp_type == AMPType.APEX: context = closure_loss closure_loss = closure_loss.__enter__() @@ -965,7 +965,7 @@ class TrainerTrainLoopMixin(ABC): model_ref.backward(self, closure_loss, optimizer, opt_idx) # exit amp context - if self.precision == 16 and not NATIVE_AMP_AVALAIBLE and not self.on_tpu: + if self.precision == 16 and self.amp_type == AMPType.APEX and not self.on_tpu: a, b, c = 
None, None, None error = context.__exit__(a, b, c) if error: diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 9aaa8730e9..14214d2432 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -24,16 +24,14 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import DummyLogger -from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda try: from apex import amp except ImportError: - APEX_AVAILABLE = False -else: - APEX_AVAILABLE = True + amp = None EPSILON = 1e-6 EPSILON_FP16 = 1e-5 @@ -48,6 +46,7 @@ class TrainerTrainingTricksMixin(ABC): default_root_dir: str progress_bar_callback: ... on_gpu: bool + amp_type: AMPType @abstractmethod def get_model(self) -> LightningModule: @@ -72,7 +71,7 @@ class TrainerTrainingTricksMixin(ABC): if self.gradient_clip_val <= 0: return model = self.get_model() - if self.use_amp and not NATIVE_AMP_AVALAIBLE: + if self.amp_type == AMPType.APEX: parameters = amp.master_params(optimizer) else: parameters = model.parameters() diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 445b6145b2..a6ea1d8467 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -1,4 +1,5 @@ """General utilities""" +from enum import Enum import numpy import torch @@ -19,3 +20,8 @@ NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "a FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps + + +class AMPType(Enum): + APEX = 'apex' + NATIVE = 'native'
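
Notes on the selection logic introduced here (illustrative sketches only, not part of the patch). The core of the change is the new `AMPType` enum plus the `_setup_amp_type` resolution step that replaces the old `use_amp` / `NATIVE_AMP_AVALAIBLE` checks. The following is a standalone, simplified restatement of that selection logic: the availability flags are passed in explicitly instead of being read from `pytorch_lightning.utilities`, and warnings are plain `warnings.warn` calls rather than `rank_zero_warn`::

    import warnings
    from enum import Enum
    from typing import Optional


    class AMPType(Enum):
        # mirrors pytorch_lightning.utilities.AMPType added by this patch
        APEX = 'apex'
        NATIVE = 'native'


    def resolve_amp_type(amp_type: str, precision: int,
                         native_available: bool, apex_available: bool) -> Optional[AMPType]:
        """Simplified restatement of TrainerAMPMixin._setup_amp_type."""
        if precision != 16:
            # no mixed precision requested, nothing to set up
            return None
        amp_type = amp_type.lower()
        assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
        if amp_type == 'native':
            if native_available:
                return AMPType.NATIVE
            # fall back to Apex when the installed PyTorch has no torch.cuda.amp
            warnings.warn('Native AMP requires PyTorch >= 1.6, falling back to NVIDIA Apex.')
            amp_type = 'apex'
        if amp_type == 'apex' and apex_available:
            return AMPType.APEX
        raise ModuleNotFoundError(
            '16-bit precision requested but neither native AMP nor Apex is available.')


    # 32-bit training: no AMP backend is selected at all
    assert resolve_amp_type('native', 32, native_available=False, apex_available=False) is None
    # 16-bit with PyTorch >= 1.6: native AMP wins
    assert resolve_amp_type('native', 16, native_available=True, apex_available=False) is AMPType.NATIVE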
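
For `AMPType.NATIVE`, the trainer creates a `torch.cuda.amp.GradScaler`, the forward pass runs under `autocast`, gradients are unscaled before clipping, and `scaler.update()` runs after the optimizer step. The snippet below condenses that flow into one plain PyTorch training step; it assumes PyTorch >= 1.6 and uses a throwaway model, optimizer, and batch, so it illustrates the standard native-AMP pattern the loop follows rather than the Lightning loop itself (on a CPU-only machine the scaler and autocast simply disable themselves with a warning)::

    import torch
    from torch import nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(8, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()  # held by the trainer as self.scaler
    batch = torch.randn(4, 8, device=device)
    target = torch.randn(4, 1, device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():                      # forward in mixed precision
        loss = nn.functional.mse_loss(model(batch), target)

    scaler.scale(loss).backward()                        # scale the loss, then backprop
    scaler.unscale_(optimizer)                           # unscale before gradient clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                               # optimizer step through the scaler
    scaler.update()                                      # update the scale after the step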
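
For `AMPType.APEX` the backends keep routing the model through `model.configure_apex(amp, model, optimizers, amp_level)`, whose default implementation calls `amp.initialize`. The sketch below shows the underlying Apex pattern on a toy model; it is guarded the same way this patch guards its imports (`amp = None` when Apex is missing) and only runs on a CUDA machine with Apex installed. The `with amp.scale_loss(...)` block is the idiomatic form of the context manager that the training loop in this patch enters and exits manually::

    import torch
    from torch import nn

    try:
        from apex import amp
    except ImportError:
        amp = None

    if amp is not None and torch.cuda.is_available():
        model = nn.Linear(8, 1).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

        # what the default configure_apex hook does: wrap model and optimizer
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

        loss = model(torch.randn(4, 8, device='cuda')).mean()
        # Apex branch of amp_scale_loss: scale_loss is a context manager
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()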
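
Checkpoint I/O now branches on `amp_type` instead of `use_amp` / `NATIVE_AMP_AVALAIBLE`: native runs store `scaler.state_dict()` under `native_amp_scaling_state`, Apex runs store `amp.state_dict()` under `amp_scaling_state`. A minimal sketch of the save/restore round trip for the native case, with the checkpoint dict and key name taken from the patch but a fresh toy scaler standing in for the trainer's::

    import torch

    scaler = torch.cuda.amp.GradScaler()

    # saving (dump_checkpoint, native branch)
    checkpoint = {'native_amp_scaling_state': scaler.state_dict()}

    # restoring (restore_training_state / hpc_load, native branch)
    new_scaler = torch.cuda.amp.GradScaler()
    if 'native_amp_scaling_state' in checkpoint:
        new_scaler.load_state_dict(checkpoint['native_amp_scaling_state'])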