Fix test_quantization with Pytorch 1.10 (#9808)

This commit is contained in:
Danielle Pintz 2021-10-07 00:54:06 -07:00 committed by GitHub
parent 8a8ecb8d01
commit 0561fd6925
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 2 deletions

View File

@ -24,6 +24,7 @@ from torch.quantization import QConfig
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -188,7 +189,12 @@ class QuantizationAwareTraining(Callback):
if self._observer_type == "histogram": if self._observer_type == "histogram":
pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig) pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
elif self._observer_type == "average": elif self._observer_type == "average":
pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig) # version=None corresponds to using FakeQuantize rather than
# FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
# details in https://github.com/pytorch/pytorch/issues/64564
extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
elif isinstance(self._qconfig, QConfig): elif isinstance(self._qconfig, QConfig):
pl_module.qconfig = self._qconfig pl_module.qconfig = self._qconfig

View File

@ -62,7 +62,7 @@ def _compare_version(package: str, op, version) -> bool:
except TypeError: except TypeError:
# this is mock by sphinx, so it shall return True ro generate all summaries # this is mock by sphinx, so it shall return True ro generate all summaries
return True return True
return op(pkg_version, Version(version)) return op(Version(pkg_version.base_version), Version(version))
_IS_WINDOWS = platform.system() == "Windows" _IS_WINDOWS = platform.system() == "Windows"