# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    """Plugin for native mixed precision training with :mod:`torch.cuda.amp`.

    Args:
        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``"bf16"``).
        use_cpu: Whether to run on CPU, in which case autocast uses :mod:`torch.cpu.amp` and
            only ``"bf16"`` precision is supported.
    """

    def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
        super().__init__()
        if use_cpu and not _TORCH_CPU_AMP_AVAILABLE:
            raise MisconfigurationException(
                "You have asked for native AMP on CPU, but AMP is only available on GPU for PyTorch 1.9 "
                "and lower. To use native AMP on CPU, install PyTorch 1.10 or later."
            )
        self.use_cpu = use_cpu
        self._dtype = self._select_precision_dtype(precision)
        self.backend = AMPType.NATIVE
        if not self.is_bfloat16:
            self.scaler = torch.cuda.amp.GradScaler()

    def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
        if precision == "bf16":
            if not _TORCH_BFLOAT_AVAILABLE:
                raise MisconfigurationException(
                    "To use bfloat16 with native AMP you must install PyTorch 1.10 or later."
                )
            return torch.bfloat16
        elif self.use_cpu:
            raise MisconfigurationException(
                "CPU native AMP only supports bfloat16. Please pass `precision='bf16'` to the Trainer."
            )
        return torch.float16

    @property
    def is_bfloat16(self) -> bool:
        return self._dtype == torch.bfloat16

    def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
        if self.is_bfloat16:
            return super().pre_backward(model, closure_loss)
        closure_loss = self.scaler.scale(closure_loss)
        return super().pre_backward(model, closure_loss)

    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
        if not self.is_bfloat16:
            tensor = self.scaler.scale(tensor)
        super()._run_backward(tensor, model, *args, **kwargs)

    def pre_optimizer_step(
        self,
        model: "pl.LightningModule",
        optimizer: Optimizer,
        optimizer_idx: int,
        lambda_closure: Callable,
        **kwargs: Any,
    ) -> bool:
        if self.is_bfloat16:
            # skip scaler logic, as bfloat16 does not require a scaler
            return super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        result = lambda_closure()  # native amp does not support closures
        self.scaler.unscale_(optimizer)
        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
        skipped_backward = result is None  # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            self.scaler.step(optimizer)
            self.scaler.update()
        return False

    def autocast_context_manager(self) -> torch.cuda.amp.autocast:
        if self.use_cpu:
            return torch.cpu.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
        if self.is_bfloat16:
            return torch.cuda.amp.autocast(dtype=self._dtype)  # Only reached in pytorch==1.10 where this is ok. skipcq
        return torch.cuda.amp.autocast()

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with self.autocast_context_manager():
            yield

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
            self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if not self.is_bfloat16:
            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
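

# Usage sketch (illustrative only, not part of the plugin): the plugin is normally
# constructed for you by the Trainer when `precision=16` or `precision="bf16"` is set,
# but it can also be instantiated directly and passed via `plugins=...`. The Trainer
# arguments below are assumptions about the surrounding API at this version of
# PyTorch Lightning, not something this module defines.
#
#   import pytorch_lightning as pl
#   from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
#
#   # float16 AMP on GPU: a GradScaler is created and used around optimizer.step
#   trainer = pl.Trainer(gpus=1, plugins=[NativeMixedPrecisionPlugin(precision=16)])
#
#   # bfloat16 AMP on CPU (requires PyTorch 1.10+): no GradScaler is created
#   trainer = pl.Trainer(plugins=[NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)])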