Mixed precision training delivers significant computational speedups by performing operations in half precision while keeping a minimal amount of data in single precision to preserve critical information in key parts of the network.
Switching to mixed precision has resulted in considerable training speedups since the introduction of Tensor Cores in the Volta and Turing architectures.
It combines FP32 and lower-bit floating points (such as FP16) to reduce memory footprint and increase performance during model training and evaluation.
It accomplishes this by identifying the steps that require full precision and using 32-bit floating point only for those steps, while using 16-bit floating point for the rest.
Compared to full precision training, mixed precision training delivers all these benefits while ensuring no task-specific accuracy is lost `[1] <https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html>`_.
Supported `PyTorch operations <https://pytorch.org/docs/stable/amp.html#op-specific-behavior>`_ automatically run in FP16, saving memory and improving throughput on the supported accelerators.
Since computation happens in FP16, there is a chance of numerical instability during training; this is handled internally by a dynamic gradient scaler that skips invalid steps and adjusts the scale so that subsequent steps fall within a finite range.
For more information `see the autocast docs <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_.
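For example, FP16 mixed precision is selected by passing the corresponding precision flag to Fabric. A minimal sketch (``MyModel`` stands in for your own module):

.. code-block:: python

    # Select FP16 mixed precision
    fabric = Fabric(precision="16-mixed")
    model = MyModel()
    model = fabric.setup(model)  # weights stay in torch.float32, supported ops run in torch.float16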
BFloat16 Mixed precision is similar to FP16 mixed precision. However, it maintains more of the "dynamic range" that FP32 offers.
This means it can provide better numerical stability than FP16 mixed precision.
For more information, see `this TPU performance blog post <https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus>`_.
For GPUs, the most significant benefits require `Ampere <https://en.wikipedia.org/wiki/Ampere_(microarchitecture)>`_ based GPUs or newer, such as A100s or 3090s.
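BF16 mixed precision is selected analogously. A minimal sketch (``MyModel`` stands in for your own module):

.. code-block:: python

    # Select BF16 mixed precision
    fabric = Fabric(precision="bf16-mixed")
    model = MyModel()
    model = fabric.setup(model)  # weights stay in torch.float32, supported ops run in torch.bfloat16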
Under the hood, we use `transformer_engine.pytorch.fp8_autocast <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.fp8_autocast>`__ with the default fp8 recipe.
.. note::

    This requires `Hopper <https://en.wikipedia.org/wiki/Hopper_(microarchitecture)>`_ based GPUs or newer, such as the H100.
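Assuming a Fabric version with Transformer Engine support and the ``transformer_engine`` package installed, FP8 can be selected as sketched below (``MyModel`` stands in for your own module):

.. code-block:: python

    # Select FP8 precision via Transformer Engine (requires the transformer_engine package)
    fabric = Fabric(precision="transformer-engine")
    model = MyModel()
    model = fabric.setup(model)  # supported layers run their compute in FP8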
As mentioned before, for numerical stability mixed precision keeps the model weights in full float32 precision while casting only supported operations to lower bit precision.
However, in some cases it is indeed possible to train completely in half precision. Similarly, for inference the model weights can often be cast to half precision without a loss in accuracy (even when trained with mixed precision).
.. code-block:: python

    # Select FP16 precision
    fabric = Fabric(precision="16-true")
    model = MyModel()
    model = fabric.setup(model)  # model gets cast to torch.float16

    # Select BF16 precision
    fabric = Fabric(precision="bf16-true")
    model = MyModel()
    model = fabric.setup(model)  # model gets cast to torch.bfloat16
Tip: For faster initialization, you can create model parameters with the desired dtype directly on the device:
.. code-block:: python

    fabric = Fabric(precision="bf16-true")

    # init the model directly on the device and with parameters in half-precision
    with fabric.init_module():
        model = MyModel()
Both 4-bit (`paper reference <https://arxiv.org/abs/2305.14314v1>`__) and 8-bit (`paper reference <https://arxiv.org/abs/2110.02861>`__) quantization are supported.
* **nf4**: Uses the normalized float 4-bit data type. This is recommended over "fp4" based on the paper's experimental results and theoretical analysis.
* **nf4-dq**: "dq" stands for "Double Quantization", which reduces the average memory footprint by quantizing the quantization constants. On average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).
* **fp4**: Uses the regular float 4-bit data type.
* **fp4-dq**: "dq" stands for "Double Quantization", which reduces the average memory footprint by quantizing the quantization constants. On average, this amounts to about 0.37 bits per parameter (approximately 3 GB for a 65B model).
* **int8**: Uses unsigned int8 data type.
* **int8-training**: Meant for int8 activations with fp16 precision weights.
The :class:`~lightning.fabric.plugins.precision.bitsandbytes.BitsandbytesPrecision` automatically replaces the :class:`torch.nn.Linear` layers in your model with their BNB alternatives.
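As a minimal sketch (assuming the ``bitsandbytes`` package is installed; ``MyModel`` stands in for your own module), the plugin is passed to Fabric like this:

.. code-block:: python

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins import BitsandbytesPrecision

    # quantize the Linear layers to NF4 with double quantization, computing in bfloat16
    precision = BitsandbytesPrecision(mode="nf4-dq", dtype=torch.bfloat16)
    fabric = Fabric(plugins=precision)

    model = MyModel()
    model = fabric.setup(model)  # torch.nn.Linear layers get replaced with their BNB alternatives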
For certain scientific computations, 64-bit precision enables more accurate models. However, doubling the precision from 32 to 64 bit also doubles the memory requirements.
.. code-block:: python

    # Select FP64 precision
    fabric = Fabric(precision="64-true")
    model = MyModel()
    model = fabric.setup(model)  # model gets cast to torch.float64
In deep learning, memory is usually the bottleneck, especially when dealing with large volumes of data and limited resources, so double precision is rarely worthwhile.
Single precision is recommended for better speed, but you can still use 64-bit precision if your particular use case requires it.
When working with complex numbers, instantiation of complex tensors should be done under the
:meth:`~lightning.fabric.fabric.Fabric.init_module` context manager so that the `complex128` dtype
is properly selected.
.. code-block:: python

    fabric = Fabric(precision="64-true")

    # init the model directly on the device and with parameters in full-precision
    with fabric.init_module():
        model = MyModel()
If you want to enable operations in lower bit-precision **outside** your model's ``forward()``, you can use the :meth:`~lightning.fabric.fabric.Fabric.autocast` context manager:
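A minimal sketch of how this might look, with ``model`` set up as before and the loss computation as a placeholder:

.. code-block:: python

    # operations inside this block run in the lower bit-precision where supported
    with fabric.autocast():
        loss = loss_function(model(batch), target)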