Fix test_quantization with Pytorch 1.10 (#9808)
commit 0561fd6925
parent 8a8ecb8d01
@@ -24,6 +24,7 @@ from torch.quantization import QConfig
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 
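For context, a minimal sketch of how the newly imported flag is presumably defined, based on the _compare_version helper patched in the last hunk below (the module path pytorch_lightning.utilities.imports is an assumption, not shown in this diff):

import operator

from pytorch_lightning.utilities.imports import _compare_version

# Assumed definition of the flag: True when the installed torch is >= 1.10.0,
# evaluated through the _compare_version helper fixed below.
_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")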
@@ -188,7 +189,12 @@ class QuantizationAwareTraining(Callback):
         if self._observer_type == "histogram":
             pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
         elif self._observer_type == "average":
-            pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig)
+            # version=None corresponds to using FakeQuantize rather than
+            # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
+            # details in https://github.com/pytorch/pytorch/issues/64564
+            extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
+            pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
+
         elif isinstance(self._qconfig, QConfig):
             pl_module.qconfig = self._qconfig
 
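As a standalone illustration of the gate introduced above, a minimal sketch (the "fbgemm" backend string is an illustrative stand-in for the callback's self._qconfig, not taken from the commit):

import torch

from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10

# On PyTorch >= 1.10, version=None keeps get_default_qat_qconfig on the older
# FakeQuantize observers instead of FusedMovingAvgObsFakeQuantize; on earlier
# versions the keyword does not exist, so no extra argument is passed.
extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
qconfig = torch.quantization.get_default_qat_qconfig("fbgemm", **extra_kwargs)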
@@ -62,7 +62,7 @@ def _compare_version(package: str, op, version) -> bool:
     except TypeError:
         # this is mocked by sphinx, so it shall return True to generate all summaries
         return True
-    return op(pkg_version, Version(version))
+    return op(Version(pkg_version.base_version), Version(version))
 
 
 _IS_WINDOWS = platform.system() == "Windows"
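The one-line change above matters for nightly builds: under PEP 440 a dev pre-release sorts before its final release, so a 1.10 nightly would fail a >= 1.10.0 check until its base version is compared instead. A minimal sketch (the version strings are illustrative):

import operator

from packaging.version import Version

nightly = Version("1.10.0.dev20210922")  # illustrative nightly build

# Direct comparison: the dev pre-release sorts before the final 1.10.0.
print(operator.ge(nightly, Version("1.10.0")))  # False

# Comparing base_version ("1.10.0") recognises the nightly as a 1.10 build.
print(operator.ge(Version(nightly.base_version), Version("1.10.0")))  # True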