diff --git a/CHANGELOG.md b/CHANGELOG.md
index 96ff7bb08b..6a68604772 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -177,12 +177,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))
 
 
+- Added `quantize_on_fit_end` argument to `QuantizationAwareTraining` ([#8464](https://github.com/PyTorchLightning/pytorch-lightning/pull/8464))
+
+
 - Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226))
 
 
 - Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440))
 
+
 ### Changed
diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py
index 9153fbacbf..b6df2b0d4e 100644
--- a/pytorch_lightning/callbacks/quantization.py
+++ b/pytorch_lightning/callbacks/quantization.py
@@ -82,6 +82,54 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
 
 
 class QuantizationAwareTraining(Callback):
+    """
+    Quantization allows speeding up inference and decreasing memory requirements
+    by performing computations and storing tensors at lower bitwidths
+    (such as INT8 or FLOAT16) than floating point precision.
+    We use the native PyTorch API, so for more information
+    see `Quantization <https://pytorch.org/docs/stable/quantization.html>`_.
+
+    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
+
+
+    Args:
+
+        qconfig: quantization configuration:
+
+            - 'fbgemm' for server inference.
+            - 'qnnpack' for mobile inference.
+            - a custom `torch.quantization.QConfig
+              <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
+
+        observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
+            and ``HistogramObserver`` as "histogram", which is more computationally expensive.
+
+        collect_quantization: count or custom function to collect quantization statistics:
+
+            - ``None`` (default). The quantization observer is called in each module forward
+                (useful for collecting extended statistics when using image/data augmentation).
+            - ``int``. Use to set a fixed number of calls, starting from the beginning.
+            - ``Callable``. Custom function with a single trainer argument.
+                See this example to trigger collection only during the last epoch:
+
+                .. code-block:: python
+
+                    def custom_trigger_last(trainer):
+                        return trainer.current_epoch == (trainer.max_epochs - 1)
+
+                    QuantizationAwareTraining(collect_quantization=custom_trigger_last)
+
+        modules_to_fuse: allows you to fuse a few layers together as shown in the
+            `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_;
+            to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
+
+        input_compatible: preserve quant/dequant layers. This allows feeding any input just as with the original
+            model, but breaks compatibility with TorchScript and export with ``torch.save``.
+
+        quantize_on_fit_end: perform the quantization in ``on_fit_end``.
+            Note that once converted, the model cannot be put in training mode again.
+
+    """
     OBSERVER_TYPES = ('histogram', 'average')
 
     def __init__(
@@ -91,51 +139,8 @@ class QuantizationAwareTraining(Callback):
         collect_quantization: Optional[Union[int, Callable]] = None,
         modules_to_fuse: Optional[Sequence] = None,
         input_compatible: bool = True,
+        quantize_on_fit_end: bool = True,
     ) -> None:
-        """
-        Quantization allows speeding up inference and decreasing memory requirements
-        by performing computations and storing tensors at lower bitwidths
-        (such as INT8 or FLOAT16) than floating point precision.
-        We use native PyTorch API so for more information
-        see `Quantization <https://pytorch.org/docs/stable/quantization.html>`_.
-
-        .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
-
-
-        Args:
-
-            qconfig: quantization configuration:
-
-                - 'fbgemm' for server inference.
-                - 'qnnpack' for mobile inference.
-                - a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
-
-            observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
-                and ``HistogramObserver`` as "histogram" which is more computationally expensive.
-
-            collect_quantization: count or custom function to collect quantization statistics:
-
-                - ``None`` (deafult). The quantization observer is called in each module forward
-                    (useful for collecting extended statistic when useing image/data augmentation).
-                - ``int``. Use to set a fixed number of calls, starting from the beginning.
-                - ``Callable``. Custom function with single trainer argument.
-                    See this example to trigger only the last epoch:
-
-                    .. code-block:: python
-
-                        def custom_trigger_last(trainer):
-                            return trainer.current_epoch == (trainer.max_epochs - 1)
-
-                        QuantizationAwareTraining(collect_quantization=custom_trigger_last)
-
-            modules_to_fuse: allows you fuse a few layers together as shown in
-                `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
-                to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
-
-            input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
-                but break compatibility to torchscript.
-
-        """  # noqa: E501
         _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
         if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
             raise MisconfigurationException(
@@ -157,6 +162,7 @@ class QuantizationAwareTraining(Callback):
 
         self.modules_to_fuse = modules_to_fuse
         self._input_compatible = input_compatible
+        self._convert_on_fit_end = quantize_on_fit_end
         self._forward_calls = 0
 
     def _check_feasible_fuse(self, model):
@@ -199,6 +205,9 @@ class QuantizationAwareTraining(Callback):
         torch.quantization.prepare_qat(pl_module, inplace=True)
 
     def on_fit_end(self, trainer, pl_module):
+        if not self._convert_on_fit_end:
+            pl_module.forward = self.__module_forward
+            return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
         # quantizes the weights, computes and stores the scale and bias value to be
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ef6b74f1de..eb2e9ae4b1 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1228,8 +1228,10 @@ class Trainer(
         return output
 
     def _parse_devices(
-        self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int],
-                                                                                                             str, int]]
+        self,
+        gpus: Optional[Union[List[int], str, int]],
+        auto_select_gpus: bool,
+        tpu_cores: Optional[Union[List[int], str, int]],
     ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
         if auto_select_gpus and isinstance(gpus, int):
             gpus = pick_multiple_gpus(gpus)
diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py
index 8627acd23f..fafcdfbc5b 100644
--- a/tests/callbacks/test_quantization.py
+++ b/tests/callbacks/test_quantization.py
@@ -28,15 +28,16 @@ from tests.helpers.simple_models import RegressionModel
 
 @pytest.mark.parametrize("observe", ['average', 'histogram'])
 @pytest.mark.parametrize("fuse", [True, False])
+@pytest.mark.parametrize("convert", [True, False])
 @RunIf(quantization=True)
-def test_quantization(tmpdir, observe: str, fuse: bool):
+def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
     """Parity test for quant model"""
     seed_everything(42)
     dm = RegressDataModule()
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_epochs=10,
-        gpus=1 if torch.cuda.is_available() else None,
+        max_epochs=7,
+        gpus=int(torch.cuda.is_available()),
     )
     model = RegressionModel()
     qmodel = copy.deepcopy(model)
@@ -47,20 +48,33 @@ def test_quantization(tmpdir, observe: str, fuse: bool):
     org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))
 
     fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None
-    qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers)
+    qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
     trainer = Trainer(callbacks=[qcb], **trainer_args)
     trainer.fit(qmodel, datamodule=dm)
 
     quant_calls = qcb._forward_calls
     assert quant_calls == qcb._forward_calls
+    quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
+    # test that the test score is almost the same as with pure training
+    assert torch.allclose(org_score, quant_score, atol=0.45)
+    model_path = trainer.checkpoint_callback.best_model_path
+
+    trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
+    if not convert:
+        trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
+        trainer.fit(qmodel, datamodule=dm)
+        qmodel.eval()
+        torch.quantization.convert(qmodel, inplace=True)
 
     quant_size = qmodel.model_size
-    quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
     # test that the trained model is smaller then initial
     size_ratio = quant_size / org_size
     assert size_ratio < 0.65
-    # test that the test score is almost the same as with pure training
-    assert torch.allclose(org_score, quant_score, atol=0.45)
+
+    # todo: make it work also with strict loading
+    qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
+    quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
+    assert torch.allclose(org_score, quant2_score, atol=0.45)
 
 
 @RunIf(quantization=True)
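
Reviewer note, not part of the patch: a minimal usage sketch of the new `quantize_on_fit_end` flag. It mirrors the manual-conversion path exercised in `test_quantization` above and assumes the `RegressionModel` / `RegressDataModule` helpers from the Lightning test suite are importable (the `tests.helpers.datamodules` import path is an assumption):

    import torch
    from pytorch_lightning import Trainer, seed_everything
    from pytorch_lightning.callbacks import QuantizationAwareTraining
    from tests.helpers.datamodules import RegressDataModule  # assumed helper location
    from tests.helpers.simple_models import RegressionModel

    seed_everything(42)
    model = RegressionModel()
    dm = RegressDataModule()

    # With quantize_on_fit_end=False, on_fit_end() only restores the original
    # forward() and returns, so the model stays a float QAT model after fit()
    # and can keep training or be checkpointed before conversion.
    qcb = QuantizationAwareTraining(qconfig='fbgemm', quantize_on_fit_end=False)
    Trainer(callbacks=[qcb], max_epochs=7).fit(model, datamodule=dm)

    # Convert explicitly once training is finished (same steps as in the test).
    model.eval()
    torch.quantization.convert(model, inplace=True)

Deferring the conversion matters because, per the docstring, a converted model cannot be put back in training mode; this is why the test runs an extra `fit()` before converting when `convert=False`.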