Fix test_quantization with Pytorch 1.10 (#9808)
This commit is contained in:
parent
8a8ecb8d01
commit
0561fd6925
|
@ -24,6 +24,7 @@ from torch.quantization import QConfig
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.callbacks.base import Callback
|
from pytorch_lightning.callbacks.base import Callback
|
||||||
|
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
|
||||||
|
|
||||||
|
@ -188,7 +189,12 @@ class QuantizationAwareTraining(Callback):
|
||||||
if self._observer_type == "histogram":
|
if self._observer_type == "histogram":
|
||||||
pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
|
pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
|
||||||
elif self._observer_type == "average":
|
elif self._observer_type == "average":
|
||||||
pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig)
|
# version=None corresponds to using FakeQuantize rather than
|
||||||
|
# FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
|
||||||
|
# details in https://github.com/pytorch/pytorch/issues/64564
|
||||||
|
extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
|
||||||
|
pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
|
||||||
|
|
||||||
elif isinstance(self._qconfig, QConfig):
|
elif isinstance(self._qconfig, QConfig):
|
||||||
pl_module.qconfig = self._qconfig
|
pl_module.qconfig = self._qconfig
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ def _compare_version(package: str, op, version) -> bool:
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# this is mock by sphinx, so it shall return True ro generate all summaries
|
# this is mock by sphinx, so it shall return True ro generate all summaries
|
||||||
return True
|
return True
|
||||||
return op(pkg_version, Version(version))
|
return op(Version(pkg_version.base_version), Version(version))
|
||||||
|
|
||||||
|
|
||||||
_IS_WINDOWS = platform.system() == "Windows"
|
_IS_WINDOWS = platform.system() == "Windows"
|
||||||
|
|
Loading…
Reference in New Issue