Updated quantization imports in PyTorch 1.10 (#9878)
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
46fa703853
commit
4ecb0d8bc9
|
@@ -20,13 +20,17 @@ import functools
|
|||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
from torch.quantization import QConfig
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_10:
|
||||
from torch.ao.quantization import QConfig
|
||||
else:
|
||||
from torch.quantization import QConfig
|
||||
|
||||
|
||||
def wrap_qat_forward_context(
|
||||
quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
|
||||
|
|
Loading…
Reference in New Issue