Docs for Pruning, Quantization, and SWA (#6041)

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
edenlightning 2021-02-18 08:51:51 -05:00 committed by GitHub
parent f48a9330ed
commit 3449e2d79f
6 changed files with 181 additions and 27 deletions


@@ -0,0 +1,119 @@
.. testsetup:: *

    import os

    from pytorch_lightning.trainer.trainer import Trainer
    from pytorch_lightning.core.lightning import LightningModule
.. _pruning_quantization:
########################
Pruning and Quantization
########################
Pruning and Quantization are techniques for compressing models for deployment, enabling faster inference and energy savings without a significant drop in accuracy.
*******
Pruning
*******
.. warning::

    Pruning is in beta and subject to change.
Pruning is a technique that eliminates some of the model weights to reduce the model size and decrease inference requirements.
Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, deploying models on edge devices, or mobile inference (among others).
To enable pruning during training in Lightning, simply pass in the :class:`~pytorch_lightning.callbacks.ModelPruning` callback to the Lightning Trainer. PyTorch's native pruning implementation is used under the hood.
This callback supports multiple pruning functions: pass any `torch.nn.utils.prune <https://pytorch.org/docs/stable/nn.html#utilities>`_ function as a string to select which weights to prune (`random_unstructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.random_unstructured.html#torch.nn.utils.prune.random_unstructured>`_, `RandomStructured <https://pytorch.org/docs/stable/generated/torch.nn.utils.prune.RandomStructured.html#torch.nn.utils.prune.RandomStructured>`_, etc) or implement your own by subclassing `BasePruningMethod <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#extending-torch-nn-utils-prune-with-custom-pruning-functions>`_.
.. code-block:: python

    from pytorch_lightning.callbacks import ModelPruning

    # set the amount to be the fraction of parameters to prune
    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])
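
If none of the built-in functions fit, you can implement your own method by subclassing ``BasePruningMethod``, as described in the PyTorch pruning tutorial linked above. The snippet below is a minimal sketch: ``ThresholdPruning`` is an illustrative class, and handing a custom class to ``ModelPruning`` this way is an assumption to verify against your installed version.

.. code-block:: python

    from torch.nn.utils import prune

    from pytorch_lightning.callbacks import ModelPruning


    class ThresholdPruning(prune.BasePruningMethod):
        PRUNING_TYPE = "unstructured"

        def __init__(self, amount):
            # interpret ``amount`` as an absolute magnitude threshold
            self.threshold = amount

        def compute_mask(self, tensor, default_mask):
            # keep only the weights whose magnitude exceeds the threshold
            return default_mask * (tensor.abs() > self.threshold)


    # assumption: the callback accepts a BasePruningMethod subclass as its pruning function
    trainer = Trainer(callbacks=[ModelPruning(ThresholdPruning, amount=0.02)])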
You can also perform iterative pruning, apply the `lottery ticket hypothesis <https://arxiv.org/pdf/1803.03635.pdf>`__, and more!
.. code-block:: python

    def compute_amount(epoch):
        # the sum of all returned values needs to be smaller than 1
        if epoch == 10:
            return 0.5
        elif epoch == 50:
            return 0.25
        elif 75 < epoch < 99:
            return 0.01


    # the amount can also be a callable
    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)])
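
The callback exposes more than just ``amount``. The example below is a sketch of a more selective setup: ``parameters_to_prune`` and ``use_lottery_ticket_hypothesis`` are assumed to be accepted arguments (check the :class:`~pytorch_lightning.callbacks.ModelPruning` API reference for your version), and ``LitModel`` with its ``layer_1``/``layer_2`` submodules is a hypothetical model.

.. code-block:: python

    from pytorch_lightning.callbacks import ModelPruning

    model = LitModel()  # hypothetical LightningModule with layer_1 / layer_2 submodules

    pruning = ModelPruning(
        pruning_fn="l1_unstructured",
        # restrict pruning to specific (module, parameter name) pairs
        parameters_to_prune=[(model.layer_1, "weight"), (model.layer_2, "weight")],
        # reuse the epoch-dependent schedule defined above
        amount=compute_amount,
        # reset the surviving weights to their original values after each pruning step
        use_lottery_ticket_hypothesis=True,
    )
    trainer = Trainer(callbacks=[pruning])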
************
Quantization
************
.. warning::

    Quantization is in beta and subject to change.
Model quantization is another performance optimization technique that allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. This is particularly beneficial during model deployment.
Quantization Aware Training (QAT) mimics the effects of quantization during training: the computations are carried out in floating-point precision, but the subsequent quantization effect is taken into account. The weights and activations are quantized to lower precision only for inference, once training is complete.
Quantization is useful when you need to serve large models on machines with limited memory, or when you need to switch between models and reducing the I/O time is important, for example when switching between monolingual speech recognition models across multiple languages.
Lightning includes the :class:`~pytorch_lightning.callbacks.QuantizationAwareTraining` callback (using PyTorch's native quantization, read more `here <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`__), which allows creating fully quantized models (compatible with torchscript).
.. code-block:: python

    import torch
    from torch import nn

    from pytorch_lightning.callbacks import QuantizationAwareTraining


    class RegressionModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_0 = nn.Linear(16, 64)
            self.layer_0a = torch.nn.ReLU()
            self.layer_1 = nn.Linear(64, 64)
            self.layer_1a = torch.nn.ReLU()
            self.layer_end = nn.Linear(64, 1)

        def forward(self, x):
            x = self.layer_0(x)
            x = self.layer_0a(x)
            x = self.layer_1(x)
            x = self.layer_1a(x)
            x = self.layer_end(x)
            return x


    trainer = Trainer(callbacks=[QuantizationAwareTraining()])
    qmodel = RegressionModel()
    trainer.fit(qmodel, ...)

    batch = next(iter(my_dataloader()))
    qmodel(qmodel.quant(batch[0]))

    tsmodel = qmodel.to_torchscript()
    tsmodel(tsmodel.quant(batch[0]))
You can further customize the callback:
.. code-block:: python

    qcb = QuantizationAwareTraining(
        # specification of quant estimation quality
        observer_type='histogram',
        # specify which layers shall be merged together to increase efficiency
        modules_to_fuse=[(f'layer_{i}', f'layer_{i}a') for i in range(2)],
        # keep the model compatible with the original inputs/outputs;
        # the model is wrapped in a shell with entry/exit quant/dequant layers
        input_compatible=True,
    )

    batch = next(iter(my_dataloader()))
    qmodel(batch[0])
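
If the quantized model needs to be shipped elsewhere, one possible follow-up (a sketch, not part of the original callback docs) is to export it to TorchScript and persist it with ``torch.jit.save``:

.. code-block:: python

    import torch

    # export the fitted, quantized model to TorchScript ...
    tsmodel = qmodel.to_torchscript()
    torch.jit.save(tsmodel, "quantized_model.pt")

    # ... and restore it later, e.g. on the serving machine
    loaded = torch.jit.load("quantized_model.pt")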


@@ -41,6 +41,24 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
----------
Stochastic Weight Averaging
---------------------------
Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost.
It can be used with both untrained and already-trained models. The SWA procedure smooths the loss landscape,
making it harder to end up in a local minimum during optimization.
For a more detailed explanation of SWA and how it works,
read `this <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`__ post by the PyTorch team.
.. seealso:: :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` (Callback)
.. testcode::

    # Enable Stochastic Weight Averaging
    trainer = Trainer(stochastic_weight_avg=True)
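
If you need more control than the Trainer flag offers, you can also pass the callback yourself. The snippet below is a sketch: the argument names shown (``swa_epoch_start``, ``swa_lrs``) are assumptions about the :class:`~pytorch_lightning.callbacks.StochasticWeightAveraging` signature, so check the API reference of your installed version.

.. code-block:: python

    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # start averaging after 80% of the training epochs and anneal
    # towards a fixed SWA learning rate (illustrative values)
    swa_callback = StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=1e-2)
    trainer = Trainer(callbacks=[swa_callback])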
----------
Auto scaling of batch size
--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into


@@ -106,6 +106,7 @@ Lightning has a few built-in callbacks.
ModelPruning
ProgressBar
ProgressBarBase
QuantizationAwareTraining
StochasticWeightAveraging
----------


@@ -111,6 +111,7 @@ PyTorch Lightning Documentation
common/single_gpu
advanced/sequences
advanced/training_tricks
advanced/pruning_quantization
advanced/transfer_learning
advanced/tpu
advanced/cluster


@@ -83,15 +83,6 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
class QuantizationAwareTraining(Callback):
    """
    Quantization allows speeding up inference and decreasing memory requirements by performing computations
    and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.

    We use native PyTorch API so for more information see
    `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_

    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
    """

    OBSERVER_TYPES = ('histogram', 'average')

    def __init__(
@@ -103,31 +94,49 @@ class QuantizationAwareTraining(Callback):
        input_compatible: bool = True,
    ) -> None:
        """
        Quantization allows speeding up inference and decreasing memory requirements
        by performing computations and storing tensors at lower bitwidths
        (such as INT8 or FLOAT16) than floating point precision.
        We use native PyTorch API so for more information
        see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

        .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.

        Args:

            qconfig: quantization configuration:

                - 'fbgemm' for server inference.
                - 'qnnpack' for mobile inference.
                - a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.

            observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
                and ``HistogramObserver`` as "histogram", which is more computationally expensive.

            collect_quantization: count or custom function to collect quantization statistics:

                - ``None`` (default). The quantization observer is called in each module forward
                  (useful for collecting extended statistics when using image/data augmentation).
                - ``int``. Use to set a fixed number of calls, starting from the beginning.
                - ``Callable``. Custom function with a single trainer argument.
                  See this example to trigger only the last epoch:

                  .. code-block:: python

                      def custom_trigger_last(trainer):
                          return trainer.current_epoch == (trainer.max_epochs - 1)

                      QuantizationAwareTraining(collect_quantization=custom_trigger_last)

            modules_to_fuse: allows you to fuse a few layers together as shown in the `diagram
                <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
                to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

            input_compatible: preserve quant/dequant layers. This allows feeding any input to the model
                as with the original one, but breaks compatibility with torchscript.
        """  # noqa: E501
        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
            raise MisconfigurationException(


@@ -63,6 +63,12 @@ class StochasticWeightAveraging(Callback):
.. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.
SWA can easily be activated directly from the Trainer as follows:

.. code-block:: python

    Trainer(stochastic_weight_avg=True)
Arguments:
swa_epoch_start: If provided as int, the procedure will start from