Disable quantization aware training observers (#8540)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
manipopopo 2021-10-25 23:46:09 +08:00 committed by GitHub
parent f8a7f3fde0
commit cfb2d87765
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 207 additions and 8 deletions

View File

@ -328,13 +328,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765)) - `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765))
- Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896)) - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901)) - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
- Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
### Deprecated ### Deprecated
@ -409,6 +410,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924)) - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))
### Removed ### Removed
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
@ -611,7 +613,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044)) - Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044))
- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764)) - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))

View File

@ -16,10 +16,20 @@ Quantization
^^^^^^^^^^^^ ^^^^^^^^^^^^
""" """
import copy
import functools import functools
from typing import Any, Callable, Optional, Sequence, Union from typing import Any, Callable, Dict, Optional, Sequence, Union
import torch import torch
from torch import Tensor
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
if _TORCH_GREATER_EQUAL_1_8:
from torch.quantization import FakeQuantizeBase
else:
# For torch 1.6 and 1.7.
from torch.quantization import FakeQuantize as FakeQuantizeBase
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.base import Callback
@ -126,11 +136,25 @@ class QuantizationAwareTraining(Callback):
quantize_on_fit_end: perform the quantization in `on_fit_end`. quantize_on_fit_end: perform the quantization in `on_fit_end`.
Note that once converted, the model cannot be put in training mode again. Note that once converted, the model cannot be put in training mode again.
observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages:
- ``'train'``: the observers can do calibration during training.
- ``'validate'``: the observers can do calibration during validating.
Note that we don't disable observers during the sanity check as the model hasn't been calibrated with
training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
- ``'test'``: the observers can do calibration during testing.
- ``'predict'``: the observers can do calibration during predicting.
Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and
``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will
not be controlled by the callback.
.. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
.. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
""" """
OBSERVER_TYPES = ("histogram", "average") OBSERVER_TYPES = ("histogram", "average")
OBSERVER_STAGES = ("train", "validate", "test", "predict")
def __init__( def __init__(
self, self,
@ -140,6 +164,7 @@ class QuantizationAwareTraining(Callback):
modules_to_fuse: Optional[Sequence] = None, modules_to_fuse: Optional[Sequence] = None,
input_compatible: bool = True, input_compatible: bool = True,
quantize_on_fit_end: bool = True, quantize_on_fit_end: bool = True,
observer_enabled_stages: Sequence[str] = ("train",),
) -> None: ) -> None:
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
if not isinstance(qconfig, QConfig) and not _valid_qconf_str: if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
@ -163,9 +188,20 @@ class QuantizationAwareTraining(Callback):
self.modules_to_fuse = modules_to_fuse self.modules_to_fuse = modules_to_fuse
self._input_compatible = input_compatible self._input_compatible = input_compatible
self._convert_on_fit_end = quantize_on_fit_end self._convert_on_fit_end = quantize_on_fit_end
self._forward_calls = 0
def _check_feasible_fuse(self, model): observer_enabled_stages = set(observer_enabled_stages)
unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES)
if unsupported_stages:
raise MisconfigurationException(
f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.'
)
self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
self._forward_calls = 0
self._fake_quant_to_initial_state_dict = {}
self._last_fake_quant_to_observer_enabled = {}
def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
if not self.modules_to_fuse: if not self.modules_to_fuse:
return False return False
for group in self.modules_to_fuse: for group in self.modules_to_fuse:
@ -175,7 +211,20 @@ class QuantizationAwareTraining(Callback):
) )
return True return True
def on_fit_start(self, trainer, pl_module): def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]:
return {
fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict
}
def _disable_observer(self, pl_module: "pl.LightningModule") -> None:
self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled()
pl_module.apply(torch.quantization.disable_observer)
def _restore_last_observer_enabled(self) -> None:
for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
fake_quant.observer_enabled.copy_(observer_enabled)
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
# QuantStub converts tensors from floating point to quantized # QuantStub converts tensors from floating point to quantized
pl_module.quant = torch.quantization.QuantStub() pl_module.quant = torch.quantization.QuantStub()
# DeQuantStub converts tensors from quantized to floating point # DeQuantStub converts tensors from quantized to floating point
@ -209,7 +258,12 @@ class QuantizationAwareTraining(Callback):
# the model that will observe weight and activation tensors during calibration. # the model that will observe weight and activation tensors during calibration.
torch.quantization.prepare_qat(pl_module, inplace=True) torch.quantization.prepare_qat(pl_module, inplace=True)
def on_fit_end(self, trainer, pl_module): fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase))
self._fake_quant_to_initial_state_dict = {
fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
}
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self._convert_on_fit_end: if not self._convert_on_fit_end:
pl_module.forward = self.__module_forward pl_module.forward = self.__module_forward
return return
@ -224,3 +278,43 @@ class QuantizationAwareTraining(Callback):
pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
else: else:
pl_module.forward = self.__module_forward pl_module.forward = self.__module_forward
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "train" in self._observer_disabled_stages:
self._disable_observer(pl_module)
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "train" in self._observer_disabled_stages:
self._restore_last_observer_enabled()
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "validate" in self._observer_disabled_stages and not trainer.sanity_checking:
# ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver``
# need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we
# don't disable observers during the sanity check so that they can infer the shapes of quantization
# parameters with validation data.
self._disable_observer(pl_module)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "validate" in self._observer_disabled_stages:
if trainer.sanity_checking:
for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items():
fake_quant.load_state_dict(state_dict)
else:
self._restore_last_observer_enabled()
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "test" in self._observer_disabled_stages:
self._disable_observer(pl_module)
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "test" in self._observer_disabled_stages:
self._restore_last_observer_enabled()
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "predict" in self._observer_disabled_stages:
self._disable_observer(pl_module)
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "predict" in self._observer_disabled_stages:
self._restore_last_observer_enabled()

View File

@ -21,11 +21,19 @@ from torchmetrics.functional import mean_absolute_percentage_error as mape
from pytorch_lightning import seed_everything, Trainer from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.memory import get_model_size_mb from pytorch_lightning.utilities.memory import get_model_size_mb
from tests.helpers.boring_model import RandomDataset
from tests.helpers.datamodules import RegressDataModule from tests.helpers.datamodules import RegressDataModule
from tests.helpers.runif import RunIf from tests.helpers.runif import RunIf
from tests.helpers.simple_models import RegressionModel from tests.helpers.simple_models import RegressionModel
if _TORCH_GREATER_EQUAL_1_8:
from torch.quantization import FakeQuantizeBase
else:
# For torch 1.6 and 1.7.
from torch.quantization import FakeQuantize as FakeQuantizeBase
@pytest.mark.parametrize("observe", ["average", "histogram"]) @pytest.mark.parametrize("observe", ["average", "histogram"])
@pytest.mark.parametrize("fuse", [True, False]) @pytest.mark.parametrize("fuse", [True, False])
@ -45,7 +53,12 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()])) org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert) qcb = QuantizationAwareTraining(
observer_type=observe,
modules_to_fuse=fusing_layers,
quantize_on_fit_end=convert,
observer_enabled_stages=("train", "validate"),
)
trainer = Trainer(callbacks=[qcb], **trainer_args) trainer = Trainer(callbacks=[qcb], **trainer_args)
trainer.fit(qmodel, datamodule=dm) trainer.fit(qmodel, datamodule=dm)
@ -105,6 +118,9 @@ def test_quantization_exceptions(tmpdir):
with pytest.raises(MisconfigurationException, match="Unsupported `collect_quantization`"): with pytest.raises(MisconfigurationException, match="Unsupported `collect_quantization`"):
QuantizationAwareTraining(collect_quantization=1.2) QuantizationAwareTraining(collect_quantization=1.2)
with pytest.raises(MisconfigurationException, match="Unsupported stages"):
QuantizationAwareTraining(observer_enabled_stages=("abc",))
fusing_layers = [(f"layers.mlp_{i}", f"layers.NONE-mlp_{i}a") for i in range(3)] fusing_layers = [(f"layers.mlp_{i}", f"layers.NONE-mlp_{i}a") for i in range(3)]
qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers) qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers)
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1) trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
@ -140,3 +156,91 @@ def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], e
trainer.fit(qmodel, datamodule=dm) trainer.fit(qmodel, datamodule=dm)
assert qcb._forward_calls == expected_count assert qcb._forward_calls == expected_count
def _get_observer_enabled(fake_quant: FakeQuantizeBase):
# ``torch.quantization.FakeQuantize`` checks ``observer_enabled[0] == 1``.
return fake_quant.observer_enabled[0] == 1
@pytest.mark.parametrize(
"observer_enabled_stages",
[("train", "validate", "test", "predict"), ("train",), ("validate",), ("test",), ("predict",), ()],
)
@RunIf(quantization=True)
def test_quantization_disable_observers(tmpdir, observer_enabled_stages):
"""Test disabling observers."""
qmodel = RegressionModel()
qcb = QuantizationAwareTraining(observer_enabled_stages=observer_enabled_stages)
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir)
# Quantize qmodel.
qcb.on_fit_start(trainer, qmodel)
fake_quants = list(module for module in qmodel.modules() if isinstance(module, FakeQuantizeBase))
# Disable some of observers before fitting.
for fake_quant in fake_quants[::2]:
fake_quant.disable_observer()
for stage, on_stage_start, on_stage_end in [
("train", qcb.on_train_start, qcb.on_train_end),
("validate", qcb.on_validation_start, qcb.on_validation_end),
("test", qcb.on_test_start, qcb.on_test_end),
("predict", qcb.on_predict_start, qcb.on_predict_end),
]:
before_stage_observer_enabled = torch.as_tensor(list(map(_get_observer_enabled, fake_quants)))
on_stage_start(trainer, qmodel)
expected_stage_observer_enabled = torch.as_tensor(
before_stage_observer_enabled if stage in observer_enabled_stages else [False] * len(fake_quants)
)
assert torch.equal(
torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), expected_stage_observer_enabled
)
on_stage_end(trainer, qmodel)
assert torch.equal(
torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), before_stage_observer_enabled
)
@RunIf(quantization=True)
def test_quantization_val_test_predict(tmpdir):
"""Test the default quantization aware training not affected by validating, testing and predicting."""
seed_everything(42)
num_features = 16
dm = RegressDataModule(num_features=num_features)
qmodel = RegressionModel()
val_test_predict_qmodel = copy.deepcopy(qmodel)
trainer = Trainer(
callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
limit_predict_batches=1,
val_check_interval=1,
num_sanity_val_steps=1,
max_epochs=4,
)
trainer.fit(val_test_predict_qmodel, datamodule=dm)
trainer.validate(model=val_test_predict_qmodel, verbose=False)
trainer.test(model=val_test_predict_qmodel, verbose=False)
trainer.predict(
model=val_test_predict_qmodel, dataloaders=[torch.utils.data.DataLoader(RandomDataset(num_features, 16))]
)
expected_qmodel = copy.deepcopy(qmodel)
# No validation in ``expected_qmodel`` fitting.
Trainer(
callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=0,
max_epochs=4,
).fit(expected_qmodel, datamodule=dm)
expected_state_dict = expected_qmodel.state_dict()
for key, value in val_test_predict_qmodel.state_dict().items():
expected_value = expected_state_dict[key]
assert torch.allclose(value, expected_value)