From cfb2d877656064fc5bf016b20568ede6ba27113b Mon Sep 17 00:00:00 2001 From: manipopopo Date: Mon, 25 Oct 2021 23:46:09 +0800 Subject: [PATCH] Disable quantization aware training observers (#8540) Co-authored-by: Jirka Borovec Co-authored-by: tchaton Co-authored-by: rohitgr7 --- CHANGELOG.md | 5 +- pytorch_lightning/callbacks/quantization.py | 104 ++++++++++++++++++- tests/callbacks/test_quantization.py | 106 +++++++++++++++++++- 3 files changed, 207 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cd333d385..c88835aef4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -328,13 +328,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765)) - - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896)) - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) +- Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540)) + ### Deprecated @@ -409,6 +410,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924)) + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) @@ -611,7 +613,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044)) - - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764)) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 564e240716..bf0088575e 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -16,10 +16,20 @@ Quantization ^^^^^^^^^^^^ """ +import copy import functools -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Dict, Optional, Sequence, Union import torch +from torch import Tensor + +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 + +if _TORCH_GREATER_EQUAL_1_8: + from torch.quantization import FakeQuantizeBase +else: + # For torch 1.6 and 1.7. + from torch.quantization import FakeQuantize as FakeQuantizeBase import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback @@ -126,11 +136,25 @@ class QuantizationAwareTraining(Callback): quantize_on_fit_end: perform the quantization in `on_fit_end`. Note that once converted, the model cannot be put in training mode again. + observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages: + + - ``'train'``: the observers can do calibration during training. + - ``'validate'``: the observers can do calibration during validating. 
+ Note that we don't disable observers during the sanity check as the model hasn't been calibrated with + training data yet. After the sanity check, the fake-quantization modules are restored to initial states. + - ``'test'``: the observers can do calibration during testing. + - ``'predict'``: the observers can do calibration during predicting. + + Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and + ``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will + not be controlled by the callback. + .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig """ OBSERVER_TYPES = ("histogram", "average") + OBSERVER_STAGES = ("train", "validate", "test", "predict") def __init__( self, @@ -140,6 +164,7 @@ class QuantizationAwareTraining(Callback): modules_to_fuse: Optional[Sequence] = None, input_compatible: bool = True, quantize_on_fit_end: bool = True, + observer_enabled_stages: Sequence[str] = ("train",), ) -> None: _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines if not isinstance(qconfig, QConfig) and not _valid_qconf_str: @@ -163,9 +188,20 @@ class QuantizationAwareTraining(Callback): self.modules_to_fuse = modules_to_fuse self._input_compatible = input_compatible self._convert_on_fit_end = quantize_on_fit_end - self._forward_calls = 0 - def _check_feasible_fuse(self, model): + observer_enabled_stages = set(observer_enabled_stages) + unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES) + if unsupported_stages: + raise MisconfigurationException( + f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.' + ) + self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages + + self._forward_calls = 0 + self._fake_quant_to_initial_state_dict = {} + self._last_fake_quant_to_observer_enabled = {} + + def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool: if not self.modules_to_fuse: return False for group in self.modules_to_fuse: @@ -175,7 +211,20 @@ class QuantizationAwareTraining(Callback): ) return True - def on_fit_start(self, trainer, pl_module): + def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]: + return { + fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict + } + + def _disable_observer(self, pl_module: "pl.LightningModule") -> None: + self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled() + pl_module.apply(torch.quantization.disable_observer) + + def _restore_last_observer_enabled(self) -> None: + for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items(): + fake_quant.observer_enabled.copy_(observer_enabled) + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: # QuantStub converts tensors from floating point to quantized pl_module.quant = torch.quantization.QuantStub() # DeQuantStub converts tensors from quantized to floating point @@ -209,7 +258,12 @@ class QuantizationAwareTraining(Callback): # the model that will observe weight and activation tensors during calibration. 
torch.quantization.prepare_qat(pl_module, inplace=True) - def on_fit_end(self, trainer, pl_module): + fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase)) + self._fake_quant_to_initial_state_dict = { + fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants + } + + def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self._convert_on_fit_end: pl_module.forward = self.__module_forward return @@ -224,3 +278,43 @@ class QuantizationAwareTraining(Callback): pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) else: pl_module.forward = self.__module_forward + + def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "train" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "train" in self._observer_disabled_stages: + self._restore_last_observer_enabled() + + def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "validate" in self._observer_disabled_stages and not trainer.sanity_checking: + # ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver`` + # need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we + # don't disable observers during the sanity check so that they can infer the shapes of quantization + # parameters with validation data. + self._disable_observer(pl_module) + + def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "validate" in self._observer_disabled_stages: + if trainer.sanity_checking: + for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items(): + fake_quant.load_state_dict(state_dict) + else: + self._restore_last_observer_enabled() + + def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "test" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "test" in self._observer_disabled_stages: + self._restore_last_observer_enabled() + + def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "predict" in self._observer_disabled_stages: + self._disable_observer(pl_module) + + def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if "predict" in self._observer_disabled_stages: + self._restore_last_observer_enabled() diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py index f548d4d98a..ec2bb66110 100644 --- a/tests/callbacks/test_quantization.py +++ b/tests/callbacks/test_quantization.py @@ -21,11 +21,19 @@ from torchmetrics.functional import mean_absolute_percentage_error as mape from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import QuantizationAwareTraining from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.memory import get_model_size_mb +from tests.helpers.boring_model import RandomDataset from tests.helpers.datamodules import RegressDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import 
RegressionModel +if _TORCH_GREATER_EQUAL_1_8: + from torch.quantization import FakeQuantizeBase +else: + # For torch 1.6 and 1.7. + from torch.quantization import FakeQuantize as FakeQuantizeBase + @pytest.mark.parametrize("observe", ["average", "histogram"]) @pytest.mark.parametrize("fuse", [True, False]) @@ -45,7 +53,12 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()])) fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None - qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert) + qcb = QuantizationAwareTraining( + observer_type=observe, + modules_to_fuse=fusing_layers, + quantize_on_fit_end=convert, + observer_enabled_stages=("train", "validate"), + ) trainer = Trainer(callbacks=[qcb], **trainer_args) trainer.fit(qmodel, datamodule=dm) @@ -105,6 +118,9 @@ def test_quantization_exceptions(tmpdir): with pytest.raises(MisconfigurationException, match="Unsupported `collect_quantization`"): QuantizationAwareTraining(collect_quantization=1.2) + with pytest.raises(MisconfigurationException, match="Unsupported stages"): + QuantizationAwareTraining(observer_enabled_stages=("abc",)) + fusing_layers = [(f"layers.mlp_{i}", f"layers.NONE-mlp_{i}a") for i in range(3)] qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers) trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1) @@ -140,3 +156,91 @@ def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], e trainer.fit(qmodel, datamodule=dm) assert qcb._forward_calls == expected_count + + +def _get_observer_enabled(fake_quant: FakeQuantizeBase): + # ``torch.quantization.FakeQuantize`` checks ``observer_enabled[0] == 1``. + return fake_quant.observer_enabled[0] == 1 + + +@pytest.mark.parametrize( + "observer_enabled_stages", + [("train", "validate", "test", "predict"), ("train",), ("validate",), ("test",), ("predict",), ()], +) +@RunIf(quantization=True) +def test_quantization_disable_observers(tmpdir, observer_enabled_stages): + """Test disabling observers.""" + qmodel = RegressionModel() + qcb = QuantizationAwareTraining(observer_enabled_stages=observer_enabled_stages) + trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir) + + # Quantize qmodel. + qcb.on_fit_start(trainer, qmodel) + fake_quants = list(module for module in qmodel.modules() if isinstance(module, FakeQuantizeBase)) + # Disable some of observers before fitting. 
+ for fake_quant in fake_quants[::2]: + fake_quant.disable_observer() + + for stage, on_stage_start, on_stage_end in [ + ("train", qcb.on_train_start, qcb.on_train_end), + ("validate", qcb.on_validation_start, qcb.on_validation_end), + ("test", qcb.on_test_start, qcb.on_test_end), + ("predict", qcb.on_predict_start, qcb.on_predict_end), + ]: + before_stage_observer_enabled = torch.as_tensor(list(map(_get_observer_enabled, fake_quants))) + + on_stage_start(trainer, qmodel) + expected_stage_observer_enabled = torch.as_tensor( + before_stage_observer_enabled if stage in observer_enabled_stages else [False] * len(fake_quants) + ) + assert torch.equal( + torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), expected_stage_observer_enabled + ) + + on_stage_end(trainer, qmodel) + assert torch.equal( + torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), before_stage_observer_enabled + ) + + +@RunIf(quantization=True) +def test_quantization_val_test_predict(tmpdir): + """Test the default quantization aware training not affected by validating, testing and predicting.""" + seed_everything(42) + num_features = 16 + dm = RegressDataModule(num_features=num_features) + qmodel = RegressionModel() + + val_test_predict_qmodel = copy.deepcopy(qmodel) + trainer = Trainer( + callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)], + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + val_check_interval=1, + num_sanity_val_steps=1, + max_epochs=4, + ) + trainer.fit(val_test_predict_qmodel, datamodule=dm) + trainer.validate(model=val_test_predict_qmodel, verbose=False) + trainer.test(model=val_test_predict_qmodel, verbose=False) + trainer.predict( + model=val_test_predict_qmodel, dataloaders=[torch.utils.data.DataLoader(RandomDataset(num_features, 16))] + ) + + expected_qmodel = copy.deepcopy(qmodel) + # No validation in ``expected_qmodel`` fitting. + Trainer( + callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)], + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=4, + ).fit(expected_qmodel, datamodule=dm) + + expected_state_dict = expected_qmodel.state_dict() + for key, value in val_test_predict_qmodel.state_dict().items(): + expected_value = expected_state_dict[key] + assert torch.allclose(value, expected_value)
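
A minimal usage sketch of the ``observer_enabled_stages`` argument introduced above, assuming the ``RegressionModel`` and ``RegressDataModule`` helpers from the test diff; the ``Trainer`` settings are illustrative only, not part of the patch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import QuantizationAwareTraining
    from tests.helpers.datamodules import RegressDataModule
    from tests.helpers.simple_models import RegressionModel

    # With the new default ``observer_enabled_stages=("train",)``, observers only
    # calibrate on training batches and stay frozen while validating/testing/predicting.
    qcb = QuantizationAwareTraining(observer_type="average", observer_enabled_stages=("train",))
    trainer = Trainer(callbacks=[qcb], max_epochs=2)  # illustrative settings

    model, dm = RegressionModel(), RegressDataModule(num_features=16)
    trainer.fit(model, datamodule=dm)
    trainer.validate(model, datamodule=dm)  # observers remain disabled here

Passing ``observer_enabled_stages=("train", "validate")``, as the updated ``test_quantization`` does, re-enables calibration during validation epochs as well.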
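
The new hooks rest on the stock ``torch.quantization`` observer toggles: each fake-quantization module carries an ``observer_enabled`` buffer, ``torch.quantization.disable_observer`` clears it, and the callback copies the saved values back when the stage ends. A standalone sketch of that save/disable/restore pattern, using a hypothetical toy model rather than anything from the patch:

    import torch
    from torch import nn

    # Hypothetical toy model, prepared for QAT much like ``on_fit_start`` prepares ``pl_module``.
    model = nn.Sequential(torch.quantization.QuantStub(), nn.Linear(4, 4), torch.quantization.DeQuantStub())
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(model.train(), inplace=True)

    fake_quants = [m for m in model.modules() if isinstance(m, torch.quantization.FakeQuantize)]

    # Save each module's ``observer_enabled`` flag (cf. ``_collect_observer_enabled``).
    saved = {fq: fq.observer_enabled.clone() for fq in fake_quants}

    # Freeze calibration for an eval-style stage (cf. ``_disable_observer``).
    model.apply(torch.quantization.disable_observer)

    # ... run validation / test / predict batches here ...

    # Copy the saved flags back (cf. ``_restore_last_observer_enabled``).
    for fq, flag in saved.items():
        fq.observer_enabled.copy_(flag)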