diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index da85b100ca..6e805b79a9 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -20,13 +20,17 @@ import functools from typing import Any, Callable, Optional, Sequence, Union import torch -from torch.quantization import QConfig import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 from pytorch_lightning.utilities.exceptions import MisconfigurationException +if _TORCH_GREATER_EQUAL_1_10: + from torch.ao.quantization import QConfig +else: + from torch.quantization import QConfig + def wrap_qat_forward_context( quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None