Disable quantization aware training observers (#8540)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: tchaton <thomas@grid.ai> Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
parent
f8a7f3fde0
commit
cfb2d87765
|
@ -328,13 +328,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765))
|
- `pytorch_lightning.utilities.grads.grad_norm` now raises an exception if parameter `norm_type <= 0` ([#9765](https://github.com/PyTorchLightning/pytorch-lightning/pull/9765))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))
|
- Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))
|
||||||
|
|
||||||
|
|
||||||
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
|
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
|
||||||
|
|
||||||
|
|
||||||
|
- Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
|
||||||
|
|
||||||
|
|
||||||
### Deprecated
|
### Deprecated
|
||||||
|
|
||||||
|
@ -409,6 +410,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))
|
- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))
|
||||||
|
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
|
|
||||||
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
|
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
|
||||||
|
@ -611,7 +613,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044))
|
- Fixed `LearningRateMonitor` logging with multiple param groups optimizer with no scheduler ([#10044](https://github.com/PyTorchLightning/pytorch-lightning/pull/10044))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
|
- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,10 +16,20 @@ Quantization
|
||||||
^^^^^^^^^^^^
|
^^^^^^^^^^^^
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import copy
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, Optional, Sequence, Union
|
from typing import Any, Callable, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
|
||||||
|
|
||||||
|
if _TORCH_GREATER_EQUAL_1_8:
|
||||||
|
from torch.quantization import FakeQuantizeBase
|
||||||
|
else:
|
||||||
|
# For torch 1.6 and 1.7.
|
||||||
|
from torch.quantization import FakeQuantize as FakeQuantizeBase
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.callbacks.base import Callback
|
from pytorch_lightning.callbacks.base import Callback
|
||||||
|
@ -126,11 +136,25 @@ class QuantizationAwareTraining(Callback):
|
||||||
quantize_on_fit_end: perform the quantization in `on_fit_end`.
|
quantize_on_fit_end: perform the quantization in `on_fit_end`.
|
||||||
Note that once converted, the model cannot be put in training mode again.
|
Note that once converted, the model cannot be put in training mode again.
|
||||||
|
|
||||||
|
observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages:
|
||||||
|
|
||||||
|
- ``'train'``: the observers can do calibration during training.
|
||||||
|
- ``'validate'``: the observers can do calibration during validating.
|
||||||
|
Note that we don't disable observers during the sanity check as the model hasn't been calibrated with
|
||||||
|
training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
|
||||||
|
- ``'test'``: the observers can do calibration during testing.
|
||||||
|
- ``'predict'``: the observers can do calibration during predicting.
|
||||||
|
|
||||||
|
Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and
|
||||||
|
``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will
|
||||||
|
not be controlled by the callback.
|
||||||
|
|
||||||
.. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
|
.. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
|
||||||
.. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
|
.. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
|
||||||
"""
|
"""
|
||||||
|
|
||||||
OBSERVER_TYPES = ("histogram", "average")
|
OBSERVER_TYPES = ("histogram", "average")
|
||||||
|
OBSERVER_STAGES = ("train", "validate", "test", "predict")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -140,6 +164,7 @@ class QuantizationAwareTraining(Callback):
|
||||||
modules_to_fuse: Optional[Sequence] = None,
|
modules_to_fuse: Optional[Sequence] = None,
|
||||||
input_compatible: bool = True,
|
input_compatible: bool = True,
|
||||||
quantize_on_fit_end: bool = True,
|
quantize_on_fit_end: bool = True,
|
||||||
|
observer_enabled_stages: Sequence[str] = ("train",),
|
||||||
) -> None:
|
) -> None:
|
||||||
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
|
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
|
||||||
if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
|
if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
|
||||||
|
@ -163,9 +188,20 @@ class QuantizationAwareTraining(Callback):
|
||||||
self.modules_to_fuse = modules_to_fuse
|
self.modules_to_fuse = modules_to_fuse
|
||||||
self._input_compatible = input_compatible
|
self._input_compatible = input_compatible
|
||||||
self._convert_on_fit_end = quantize_on_fit_end
|
self._convert_on_fit_end = quantize_on_fit_end
|
||||||
self._forward_calls = 0
|
|
||||||
|
|
||||||
def _check_feasible_fuse(self, model):
|
observer_enabled_stages = set(observer_enabled_stages)
|
||||||
|
unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES)
|
||||||
|
if unsupported_stages:
|
||||||
|
raise MisconfigurationException(
|
||||||
|
f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.'
|
||||||
|
)
|
||||||
|
self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
|
||||||
|
|
||||||
|
self._forward_calls = 0
|
||||||
|
self._fake_quant_to_initial_state_dict = {}
|
||||||
|
self._last_fake_quant_to_observer_enabled = {}
|
||||||
|
|
||||||
|
def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
|
||||||
if not self.modules_to_fuse:
|
if not self.modules_to_fuse:
|
||||||
return False
|
return False
|
||||||
for group in self.modules_to_fuse:
|
for group in self.modules_to_fuse:
|
||||||
|
@ -175,7 +211,20 @@ class QuantizationAwareTraining(Callback):
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def on_fit_start(self, trainer, pl_module):
|
def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]:
|
||||||
|
return {
|
||||||
|
fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
def _disable_observer(self, pl_module: "pl.LightningModule") -> None:
|
||||||
|
self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled()
|
||||||
|
pl_module.apply(torch.quantization.disable_observer)
|
||||||
|
|
||||||
|
def _restore_last_observer_enabled(self) -> None:
|
||||||
|
for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
|
||||||
|
fake_quant.observer_enabled.copy_(observer_enabled)
|
||||||
|
|
||||||
|
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
# QuantStub converts tensors from floating point to quantized
|
# QuantStub converts tensors from floating point to quantized
|
||||||
pl_module.quant = torch.quantization.QuantStub()
|
pl_module.quant = torch.quantization.QuantStub()
|
||||||
# DeQuantStub converts tensors from quantized to floating point
|
# DeQuantStub converts tensors from quantized to floating point
|
||||||
|
@ -209,7 +258,12 @@ class QuantizationAwareTraining(Callback):
|
||||||
# the model that will observe weight and activation tensors during calibration.
|
# the model that will observe weight and activation tensors during calibration.
|
||||||
torch.quantization.prepare_qat(pl_module, inplace=True)
|
torch.quantization.prepare_qat(pl_module, inplace=True)
|
||||||
|
|
||||||
def on_fit_end(self, trainer, pl_module):
|
fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase))
|
||||||
|
self._fake_quant_to_initial_state_dict = {
|
||||||
|
fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
|
||||||
|
}
|
||||||
|
|
||||||
|
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
if not self._convert_on_fit_end:
|
if not self._convert_on_fit_end:
|
||||||
pl_module.forward = self.__module_forward
|
pl_module.forward = self.__module_forward
|
||||||
return
|
return
|
||||||
|
@ -224,3 +278,43 @@ class QuantizationAwareTraining(Callback):
|
||||||
pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
|
pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
|
||||||
else:
|
else:
|
||||||
pl_module.forward = self.__module_forward
|
pl_module.forward = self.__module_forward
|
||||||
|
|
||||||
|
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "train" in self._observer_disabled_stages:
|
||||||
|
self._disable_observer(pl_module)
|
||||||
|
|
||||||
|
def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "train" in self._observer_disabled_stages:
|
||||||
|
self._restore_last_observer_enabled()
|
||||||
|
|
||||||
|
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "validate" in self._observer_disabled_stages and not trainer.sanity_checking:
|
||||||
|
# ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver``
|
||||||
|
# need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we
|
||||||
|
# don't disable observers during the sanity check so that they can infer the shapes of quantization
|
||||||
|
# parameters with validation data.
|
||||||
|
self._disable_observer(pl_module)
|
||||||
|
|
||||||
|
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "validate" in self._observer_disabled_stages:
|
||||||
|
if trainer.sanity_checking:
|
||||||
|
for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items():
|
||||||
|
fake_quant.load_state_dict(state_dict)
|
||||||
|
else:
|
||||||
|
self._restore_last_observer_enabled()
|
||||||
|
|
||||||
|
def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "test" in self._observer_disabled_stages:
|
||||||
|
self._disable_observer(pl_module)
|
||||||
|
|
||||||
|
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "test" in self._observer_disabled_stages:
|
||||||
|
self._restore_last_observer_enabled()
|
||||||
|
|
||||||
|
def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "predict" in self._observer_disabled_stages:
|
||||||
|
self._disable_observer(pl_module)
|
||||||
|
|
||||||
|
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||||
|
if "predict" in self._observer_disabled_stages:
|
||||||
|
self._restore_last_observer_enabled()
|
||||||
|
|
|
@ -21,11 +21,19 @@ from torchmetrics.functional import mean_absolute_percentage_error as mape
|
||||||
from pytorch_lightning import seed_everything, Trainer
|
from pytorch_lightning import seed_everything, Trainer
|
||||||
from pytorch_lightning.callbacks import QuantizationAwareTraining
|
from pytorch_lightning.callbacks import QuantizationAwareTraining
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
|
||||||
from pytorch_lightning.utilities.memory import get_model_size_mb
|
from pytorch_lightning.utilities.memory import get_model_size_mb
|
||||||
|
from tests.helpers.boring_model import RandomDataset
|
||||||
from tests.helpers.datamodules import RegressDataModule
|
from tests.helpers.datamodules import RegressDataModule
|
||||||
from tests.helpers.runif import RunIf
|
from tests.helpers.runif import RunIf
|
||||||
from tests.helpers.simple_models import RegressionModel
|
from tests.helpers.simple_models import RegressionModel
|
||||||
|
|
||||||
|
if _TORCH_GREATER_EQUAL_1_8:
|
||||||
|
from torch.quantization import FakeQuantizeBase
|
||||||
|
else:
|
||||||
|
# For torch 1.6 and 1.7.
|
||||||
|
from torch.quantization import FakeQuantize as FakeQuantizeBase
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("observe", ["average", "histogram"])
|
@pytest.mark.parametrize("observe", ["average", "histogram"])
|
||||||
@pytest.mark.parametrize("fuse", [True, False])
|
@pytest.mark.parametrize("fuse", [True, False])
|
||||||
|
@ -45,7 +53,12 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
|
||||||
org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))
|
org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))
|
||||||
|
|
||||||
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
|
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
|
||||||
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
|
qcb = QuantizationAwareTraining(
|
||||||
|
observer_type=observe,
|
||||||
|
modules_to_fuse=fusing_layers,
|
||||||
|
quantize_on_fit_end=convert,
|
||||||
|
observer_enabled_stages=("train", "validate"),
|
||||||
|
)
|
||||||
trainer = Trainer(callbacks=[qcb], **trainer_args)
|
trainer = Trainer(callbacks=[qcb], **trainer_args)
|
||||||
trainer.fit(qmodel, datamodule=dm)
|
trainer.fit(qmodel, datamodule=dm)
|
||||||
|
|
||||||
|
@ -105,6 +118,9 @@ def test_quantization_exceptions(tmpdir):
|
||||||
with pytest.raises(MisconfigurationException, match="Unsupported `collect_quantization`"):
|
with pytest.raises(MisconfigurationException, match="Unsupported `collect_quantization`"):
|
||||||
QuantizationAwareTraining(collect_quantization=1.2)
|
QuantizationAwareTraining(collect_quantization=1.2)
|
||||||
|
|
||||||
|
with pytest.raises(MisconfigurationException, match="Unsupported stages"):
|
||||||
|
QuantizationAwareTraining(observer_enabled_stages=("abc",))
|
||||||
|
|
||||||
fusing_layers = [(f"layers.mlp_{i}", f"layers.NONE-mlp_{i}a") for i in range(3)]
|
fusing_layers = [(f"layers.mlp_{i}", f"layers.NONE-mlp_{i}a") for i in range(3)]
|
||||||
qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers)
|
qcb = QuantizationAwareTraining(modules_to_fuse=fusing_layers)
|
||||||
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
|
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
|
||||||
|
@ -140,3 +156,91 @@ def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], e
|
||||||
trainer.fit(qmodel, datamodule=dm)
|
trainer.fit(qmodel, datamodule=dm)
|
||||||
|
|
||||||
assert qcb._forward_calls == expected_count
|
assert qcb._forward_calls == expected_count
|
||||||
|
|
||||||
|
|
||||||
|
def _get_observer_enabled(fake_quant: FakeQuantizeBase):
|
||||||
|
# ``torch.quantization.FakeQuantize`` checks ``observer_enabled[0] == 1``.
|
||||||
|
return fake_quant.observer_enabled[0] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"observer_enabled_stages",
|
||||||
|
[("train", "validate", "test", "predict"), ("train",), ("validate",), ("test",), ("predict",), ()],
|
||||||
|
)
|
||||||
|
@RunIf(quantization=True)
|
||||||
|
def test_quantization_disable_observers(tmpdir, observer_enabled_stages):
|
||||||
|
"""Test disabling observers."""
|
||||||
|
qmodel = RegressionModel()
|
||||||
|
qcb = QuantizationAwareTraining(observer_enabled_stages=observer_enabled_stages)
|
||||||
|
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir)
|
||||||
|
|
||||||
|
# Quantize qmodel.
|
||||||
|
qcb.on_fit_start(trainer, qmodel)
|
||||||
|
fake_quants = list(module for module in qmodel.modules() if isinstance(module, FakeQuantizeBase))
|
||||||
|
# Disable some of observers before fitting.
|
||||||
|
for fake_quant in fake_quants[::2]:
|
||||||
|
fake_quant.disable_observer()
|
||||||
|
|
||||||
|
for stage, on_stage_start, on_stage_end in [
|
||||||
|
("train", qcb.on_train_start, qcb.on_train_end),
|
||||||
|
("validate", qcb.on_validation_start, qcb.on_validation_end),
|
||||||
|
("test", qcb.on_test_start, qcb.on_test_end),
|
||||||
|
("predict", qcb.on_predict_start, qcb.on_predict_end),
|
||||||
|
]:
|
||||||
|
before_stage_observer_enabled = torch.as_tensor(list(map(_get_observer_enabled, fake_quants)))
|
||||||
|
|
||||||
|
on_stage_start(trainer, qmodel)
|
||||||
|
expected_stage_observer_enabled = torch.as_tensor(
|
||||||
|
before_stage_observer_enabled if stage in observer_enabled_stages else [False] * len(fake_quants)
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), expected_stage_observer_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
on_stage_end(trainer, qmodel)
|
||||||
|
assert torch.equal(
|
||||||
|
torch.as_tensor(list(map(_get_observer_enabled, fake_quants))), before_stage_observer_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@RunIf(quantization=True)
|
||||||
|
def test_quantization_val_test_predict(tmpdir):
|
||||||
|
"""Test the default quantization aware training not affected by validating, testing and predicting."""
|
||||||
|
seed_everything(42)
|
||||||
|
num_features = 16
|
||||||
|
dm = RegressDataModule(num_features=num_features)
|
||||||
|
qmodel = RegressionModel()
|
||||||
|
|
||||||
|
val_test_predict_qmodel = copy.deepcopy(qmodel)
|
||||||
|
trainer = Trainer(
|
||||||
|
callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=1,
|
||||||
|
limit_test_batches=1,
|
||||||
|
limit_predict_batches=1,
|
||||||
|
val_check_interval=1,
|
||||||
|
num_sanity_val_steps=1,
|
||||||
|
max_epochs=4,
|
||||||
|
)
|
||||||
|
trainer.fit(val_test_predict_qmodel, datamodule=dm)
|
||||||
|
trainer.validate(model=val_test_predict_qmodel, verbose=False)
|
||||||
|
trainer.test(model=val_test_predict_qmodel, verbose=False)
|
||||||
|
trainer.predict(
|
||||||
|
model=val_test_predict_qmodel, dataloaders=[torch.utils.data.DataLoader(RandomDataset(num_features, 16))]
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_qmodel = copy.deepcopy(qmodel)
|
||||||
|
# No validation in ``expected_qmodel`` fitting.
|
||||||
|
Trainer(
|
||||||
|
callbacks=[QuantizationAwareTraining(quantize_on_fit_end=False)],
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=1,
|
||||||
|
limit_val_batches=0,
|
||||||
|
max_epochs=4,
|
||||||
|
).fit(expected_qmodel, datamodule=dm)
|
||||||
|
|
||||||
|
expected_state_dict = expected_qmodel.state_dict()
|
||||||
|
for key, value in val_test_predict_qmodel.state_dict().items():
|
||||||
|
expected_value = expected_state_dict[key]
|
||||||
|
assert torch.allclose(value, expected_value)
|
||||||
|
|
Loading…
Reference in New Issue