Updated quantization imports in PyTorch 1.10 (#9878)

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
theory-in-progress 2021-10-11 16:53:21 +05:30 committed by GitHub
parent 46fa703853
commit 4ecb0d8bc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@@ -20,13 +20,17 @@ import functools
from typing import Any, Callable, Optional, Sequence, Union
import torch
from torch.quantization import QConfig
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _TORCH_GREATER_EQUAL_1_10:
from torch.ao.quantization import QConfig
else:
from torch.quantization import QConfig
def wrap_qat_forward_context(
quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None