Quant as optional step (#8464)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Jirka Borovec 2021-07-22 14:44:27 +02:00 committed by GitHub
parent 5452590872
commit b7dbcc3e13
4 changed files with 82 additions and 53 deletions

View File

@@ -177,12 +177,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))
- Added `quantize_on_fit_end` argument to `QuantizationAwareTraining` ([#8464](https://github.com/PyTorchLightning/pytorch-lightning/pull/8464))
- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226))
- Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440))
### Changed

View File

@@ -82,6 +82,54 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
class QuantizationAwareTraining(Callback):
"""
Quantization allows speeding up inference and decreasing memory requirements
by performing computations and storing tensors at lower bitwidths
(such as INT8 or FLOAT16) than floating point precision.
We use the native PyTorch API, so for more information
see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.
.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
Args:
qconfig: quantization configuration:
- 'fbgemm' for server inference.
- 'qnnpack' for mobile inference.
- a custom `torch.quantization.QConfig
<https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
and ``HistogramObserver`` as "histogram" which is more computationally expensive.
collect_quantization: count or custom function to collect quantization statistics:
- ``None`` (default). The quantization observer is called in each module forward
(useful for collecting extended statistics when using image/data augmentation).
- ``int``. Use to set a fixed number of calls, starting from the beginning.
- ``Callable``. Custom function with single trainer argument.
See this example to trigger only the last epoch:
.. code-block:: python
def custom_trigger_last(trainer):
return trainer.current_epoch == (trainer.max_epochs - 1)
QuantizationAwareTraining(collect_quantization=custom_trigger_last)
modules_to_fuse: allows you to fuse a few layers together as shown in the
`diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
input_compatible: preserve quant/dequant layers. This allows feeding any input just as to the original model,
but breaks compatibility with TorchScript and export with ``torch.save``.
quantize_on_fit_end: perform the quantization in `on_fit_end`.
Note that once converted, the model cannot be put in training mode again.
"""
OBSERVER_TYPES = ('histogram', 'average')
def __init__(
@@ -91,51 +139,8 @@ class QuantizationAwareTraining(Callback):
collect_quantization: Optional[Union[int, Callable]] = None,
modules_to_fuse: Optional[Sequence] = None,
input_compatible: bool = True,
quantize_on_fit_end: bool = True,
) -> None:
"""
Quantization allows speeding up inference and decreasing memory requirements
by performing computations and storing tensors at lower bitwidths
(such as INT8 or FLOAT16) than floating point precision.
We use native PyTorch API so for more information
see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.
.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
Args:
qconfig: quantization configuration:
- 'fbgemm' for server inference.
- 'qnnpack' for mobile inference.
- a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
and ``HistogramObserver`` as "histogram" which is more computationally expensive.
collect_quantization: count or custom function to collect quantization statistics:
- ``None`` (default). The quantization observer is called in each module forward
(useful for collecting extended statistics when using image/data augmentation).
- ``int``. Use to set a fixed number of calls, starting from the beginning.
- ``Callable``. Custom function with single trainer argument.
See this example to trigger only the last epoch:
.. code-block:: python
def custom_trigger_last(trainer):
return trainer.current_epoch == (trainer.max_epochs - 1)
QuantizationAwareTraining(collect_quantization=custom_trigger_last)
modules_to_fuse: allows you to fuse a few layers together as shown in the
`diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
input_compatible: preserve quant/dequant layers. This allows feeding any input just as to the original model,
but breaks compatibility with TorchScript.
""" # noqa: E501
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
raise MisconfigurationException(
@@ -157,6 +162,7 @@ class QuantizationAwareTraining(Callback):
self.modules_to_fuse = modules_to_fuse
self._input_compatible = input_compatible
self._convert_on_fit_end = quantize_on_fit_end
self._forward_calls = 0
def _check_feasible_fuse(self, model):
@@ -199,6 +205,9 @@ class QuantizationAwareTraining(Callback):
torch.quantization.prepare_qat(pl_module, inplace=True)
def on_fit_end(self, trainer, pl_module):
if not self._convert_on_fit_end:
pl_module.forward = self.__module_forward
return
pl_module.eval()
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
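When ``quantize_on_fit_end=False`` is passed, ``on_fit_end`` only restores the original ``forward`` and skips the conversion, leaving it to the user. A short sketch of that manual path, assuming ``qmodel`` was trained with the callback (this mirrors the updated test further below):

    import torch

    # the model was only *prepared* for QAT during fit; convert it manually afterwards
    qmodel.eval()
    torch.quantization.convert(qmodel, inplace=True)
    # once converted, the model can no longer be put back into training mode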

View File

@@ -1228,8 +1228,10 @@ class Trainer(
return output
def _parse_devices(
self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int],
str, int]]
self,
gpus: Optional[Union[List[int], str, int]],
auto_select_gpus: bool,
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
if auto_select_gpus and isinstance(gpus, int):
gpus = pick_multiple_gpus(gpus)
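For context, a hedged sketch of the user-facing Trainer arguments that ``_parse_devices`` normalizes; the values are illustrative only:

    from pytorch_lightning import Trainer

    # `gpus` may be an int, a list of ids, or a comma-separated string;
    # with `auto_select_gpus=True` and an int, free GPUs are picked via `pick_multiple_gpus`
    trainer = Trainer(gpus="0,1", auto_select_gpus=False, tpu_cores=None)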

View File

@@ -28,15 +28,16 @@ from tests.helpers.simple_models import RegressionModel
@pytest.mark.parametrize("observe", ['average', 'histogram'])
@pytest.mark.parametrize("fuse", [True, False])
@pytest.mark.parametrize("convert", [True, False])
@RunIf(quantization=True)
def test_quantization(tmpdir, observe: str, fuse: bool):
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
"""Parity test for quant model"""
seed_everything(42)
dm = RegressDataModule()
trainer_args = dict(
default_root_dir=tmpdir,
max_epochs=10,
gpus=1 if torch.cuda.is_available() else None,
max_epochs=7,
gpus=int(torch.cuda.is_available()),
)
model = RegressionModel()
qmodel = copy.deepcopy(model)
@@ -47,20 +48,33 @@ def test_quantization(tmpdir, observe: str, fuse: bool):
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))
fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers)
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
trainer = Trainer(callbacks=[qcb], **trainer_args)
trainer.fit(qmodel, datamodule=dm)
quant_calls = qcb._forward_calls
assert quant_calls == qcb._forward_calls
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
# test that the test score is almost the same as with pure training
assert torch.allclose(org_score, quant_score, atol=0.45)
model_path = trainer.checkpoint_callback.best_model_path
trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
if not convert:
trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
trainer.fit(qmodel, datamodule=dm)
qmodel.eval()
torch.quantization.convert(qmodel, inplace=True)
quant_size = qmodel.model_size
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
# test that the trained model is smaller than the initial one
size_ratio = quant_size / org_size
assert size_ratio < 0.65
# test that the test score is almost the same as with pure training
assert torch.allclose(org_score, quant_score, atol=0.45)
# todo: make it work also with strict loading
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
assert torch.allclose(org_score, quant2_score, atol=0.45)
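The ``mean_relative_error`` used above comes from the test helpers; a plausible stand-in definition, shown only to make the parity check concrete (an assumption, not the actual helper):

    import torch

    def mean_relative_error(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # elementwise |pred - target| / |target|, averaged over the batch
        return torch.mean(torch.abs(pred - target) / torch.abs(target))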
@RunIf(quantization=True)