Quant as optional step (#8464)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 5452590872
commit b7dbcc3e13
@@ -177,12 +177,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))

- Added `quantize_on_fit_end` argument to `QuantizationAwareTraining` ([#8464](https://github.com/PyTorchLightning/pytorch-lightning/pull/8464))

- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226))

- Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440))

### Changed
@@ -82,6 +82,54 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:


class QuantizationAwareTraining(Callback):
    """
    Quantization allows speeding up inference and decreasing memory requirements
    by performing computations and storing tensors at lower bitwidths
    (such as INT8 or FLOAT16) than floating point precision.
    We use the native PyTorch API; for more information
    see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.


    Args:

        qconfig: quantization configuration:

            - 'fbgemm' for server inference.
            - 'qnnpack' for mobile inference.
            - a custom `torch.quantization.QConfig
              <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.

        observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
            and ``HistogramObserver`` as "histogram", which is more computationally expensive.

        collect_quantization: count or custom function to collect quantization statistics:

            - ``None`` (default). The quantization observer is called in each module forward
              (useful for collecting extended statistics when using image/data augmentation).
            - ``int``. Use to set a fixed number of calls, starting from the beginning.
            - ``Callable``. Custom function with a single trainer argument.
              See this example to trigger only the last epoch:

              .. code-block:: python

                  def custom_trigger_last(trainer):
                      return trainer.current_epoch == (trainer.max_epochs - 1)

                  QuantizationAwareTraining(collect_quantization=custom_trigger_last)

        modules_to_fuse: allows you to fuse a few layers together as shown in the
            `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
            to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

        input_compatible: preserve quant/dequant layers. This allows feeding the quantized model
            the same inputs as the original model, but breaks compatibility with TorchScript
            and export with ``torch.save``.

        quantize_on_fit_end: perform the quantization in `on_fit_end`.
            Note that once converted, the model cannot be put in training mode again.

    """
    OBSERVER_TYPES = ('histogram', 'average')

    def __init__(
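For orientation, a minimal usage sketch of the callback described by the docstring above. It assumes ``QuantizationAwareTraining`` is importable from ``pytorch_lightning.callbacks``; ``MyModel`` and ``my_datamodule`` are hypothetical placeholders, and the argument values are illustrative only.

.. code-block:: python

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import QuantizationAwareTraining

    model = MyModel()  # hypothetical LightningModule, not part of this diff

    qcb = QuantizationAwareTraining(
        qconfig="fbgemm",          # server backend; "qnnpack" targets mobile
        observer_type="average",   # or "histogram" for the more expensive observer
        quantize_on_fit_end=True,  # new in #8464: convert automatically when fit ends
    )

    trainer = Trainer(callbacks=[qcb], max_epochs=7)
    trainer.fit(model, datamodule=my_datamodule)  # my_datamodule is a placeholder DataModule
    # With quantize_on_fit_end=True, the model is converted to its quantized form here.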
@@ -91,51 +139,8 @@ class QuantizationAwareTraining(Callback):
        collect_quantization: Optional[Union[int, Callable]] = None,
        modules_to_fuse: Optional[Sequence] = None,
        input_compatible: bool = True,
        quantize_on_fit_end: bool = True,
    ) -> None:
        """
        Quantization allows speeding up inference and decreasing memory requirements
        by performing computations and storing tensors at lower bitwidths
        (such as INT8 or FLOAT16) than floating point precision.
        We use the native PyTorch API; for more information
        see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

        .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.


        Args:

            qconfig: quantization configuration:

                - 'fbgemm' for server inference.
                - 'qnnpack' for mobile inference.
                - a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.

            observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
                and ``HistogramObserver`` as "histogram", which is more computationally expensive.

            collect_quantization: count or custom function to collect quantization statistics:

                - ``None`` (default). The quantization observer is called in each module forward
                  (useful for collecting extended statistics when using image/data augmentation).
                - ``int``. Use to set a fixed number of calls, starting from the beginning.
                - ``Callable``. Custom function with a single trainer argument.
                  See this example to trigger only the last epoch:

                  .. code-block:: python

                      def custom_trigger_last(trainer):
                          return trainer.current_epoch == (trainer.max_epochs - 1)

                      QuantizationAwareTraining(collect_quantization=custom_trigger_last)

            modules_to_fuse: allows you to fuse a few layers together as shown in the
                `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
                to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

            input_compatible: preserve quant/dequant layers. This allows feeding the quantized model
                the same inputs as the original model, but breaks compatibility with TorchScript.

        """  # noqa: E501
        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
            raise MisconfigurationException(
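The validation above accepts either a backend string that appears in ``torch.backends.quantized.supported_engines`` or a ``torch.quantization.QConfig`` object. A minimal sketch of both forms, assuming a ``torch`` build that exposes the ``torch.quantization`` namespace; the backend names are illustrative.

.. code-block:: python

    import torch

    from pytorch_lightning.callbacks import QuantizationAwareTraining

    # String form: must name an engine compiled into this torch build.
    print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'fbgemm']
    qcb_str = QuantizationAwareTraining(qconfig="fbgemm")

    # QConfig form: reuse PyTorch's default QAT config, or build a custom QConfig.
    custom_qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    qcb_obj = QuantizationAwareTraining(qconfig=custom_qconfig)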
@@ -157,6 +162,7 @@ class QuantizationAwareTraining(Callback):

        self.modules_to_fuse = modules_to_fuse
        self._input_compatible = input_compatible
        self._convert_on_fit_end = quantize_on_fit_end
        self._forward_calls = 0

    def _check_feasible_fuse(self, model):
@@ -199,6 +205,9 @@ class QuantizationAwareTraining(Callback):
        torch.quantization.prepare_qat(pl_module, inplace=True)

    def on_fit_end(self, trainer, pl_module):
        if not self._convert_on_fit_end:
            pl_module.forward = self.__module_forward
            return
        pl_module.eval()
        # Convert the observed model to a quantized model. This does several things:
        # quantizes the weights, computes and stores the scale and bias value to be
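When ``quantize_on_fit_end=False``, ``on_fit_end`` only restores the original ``forward`` and leaves the model un-converted, so conversion has to be done manually, as the updated test below does. A minimal sketch of that manual step, where ``model`` stands for the trained LightningModule:

.. code-block:: python

    # After trainer.fit(...) with QuantizationAwareTraining(quantize_on_fit_end=False),
    # the module is still QAT-prepared but not yet converted:
    model.eval()
    torch.quantization.convert(model, inplace=True)  # quantize weights in place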
@@ -1228,8 +1228,10 @@ class Trainer(
        return output

    def _parse_devices(
        self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int],
            str, int]]
        self,
        gpus: Optional[Union[List[int], str, int]],
        auto_select_gpus: bool,
        tpu_cores: Optional[Union[List[int], str, int]],
    ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
        if auto_select_gpus and isinstance(gpus, int):
            gpus = pick_multiple_gpus(gpus)
@@ -28,15 +28,16 @@ from tests.helpers.simple_models import RegressionModel

@pytest.mark.parametrize("observe", ['average', 'histogram'])
@pytest.mark.parametrize("fuse", [True, False])
@pytest.mark.parametrize("convert", [True, False])
@RunIf(quantization=True)
def test_quantization(tmpdir, observe: str, fuse: bool):
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
    """Parity test for quant model"""
    seed_everything(42)
    dm = RegressDataModule()
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=10,
        gpus=1 if torch.cuda.is_available() else None,
        max_epochs=7,
        gpus=int(torch.cuda.is_available()),
    )
    model = RegressionModel()
    qmodel = copy.deepcopy(model)
@@ -47,20 +48,33 @@ def test_quantization(tmpdir, observe: str, fuse: bool):
    org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))

    fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None
    qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers)
    qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
    trainer = Trainer(callbacks=[qcb], **trainer_args)
    trainer.fit(qmodel, datamodule=dm)

    quant_calls = qcb._forward_calls
    assert quant_calls == qcb._forward_calls
    quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
    # test that the test score is almost the same as with pure training
    assert torch.allclose(org_score, quant_score, atol=0.45)
    model_path = trainer.checkpoint_callback.best_model_path

    trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
    if not convert:
        trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
        trainer.fit(qmodel, datamodule=dm)
        qmodel.eval()
        torch.quantization.convert(qmodel, inplace=True)

    quant_size = qmodel.model_size
    quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
    # test that the trained model is smaller than the initial one
    size_ratio = quant_size / org_size
    assert size_ratio < 0.65
    # test that the test score is almost the same as with pure training
    assert torch.allclose(org_score, quant_score, atol=0.45)

    # todo: make it work also with strict loading
    qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
    quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
    assert torch.allclose(org_score, quant2_score, atol=0.45)


@RunIf(quantization=True)