From 0561fd6925a0f4aa6d119c5eb451dfb6e598f83f Mon Sep 17 00:00:00 2001 From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> Date: Thu, 7 Oct 2021 00:54:06 -0700 Subject: [PATCH] Fix test_quantization with Pytorch 1.10 (#9808) --- pytorch_lightning/callbacks/quantization.py | 8 +++++++- pytorch_lightning/utilities/imports.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index 140c774dde..da85b100ca 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -24,6 +24,7 @@ from torch.quantization import QConfig import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -188,7 +189,12 @@ class QuantizationAwareTraining(Callback): if self._observer_type == "histogram": pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig) elif self._observer_type == "average": - pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig) + # version=None corresponds to using FakeQuantize rather than + # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10 + # details in https://github.com/pytorch/pytorch/issues/64564 + extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {} + pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs) + elif isinstance(self._qconfig, QConfig): pl_module.qconfig = self._qconfig diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index aa59bc4537..870106f726 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -62,7 +62,7 @@ def _compare_version(package: str, op, version) -> bool: except TypeError: # this is mock by sphinx, so it shall return True ro generate all summaries return True - return op(pkg_version, Version(version)) + return op(Version(pkg_version.base_version), Version(version)) _IS_WINDOWS = platform.system() == "Windows"