Fix mypy errors attributed to `pytorch_lightning.callbacks.quantization` (#13782)

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Krishna Kalyan authored on 2022-08-20 00:39:16 +01:00, committed by GitHub
parent 0b1a29b5eb
commit 0ca3b5aa1b
3 changed files with 29 additions and 22 deletions

pyproject.toml

@@ -50,7 +50,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",

src/pytorch_lightning/callbacks/quantization.py

@@ -41,25 +41,28 @@ else:


 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: "QuantizationAwareTraining",
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+    def wrapper(data: Any) -> Any:
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
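
The type-relevant change in this hunk is swapping `isinstance(trigger_condition, Callable)` for `callable(trigger_condition)`, which mypy accepts as a narrowing check. As a minimal standalone sketch of the decision the wrapper makes (the counter and trainer arguments here are stand-ins, not Lightning internals):

from typing import Any, Callable, Optional, Union


def should_quantize(
    trigger_condition: Optional[Union[Callable[[Any], bool], int]],
    forward_calls: int,
    trainer: Any,
) -> bool:
    """Stand-in for the check inside ``wrapper``: run (de)quantization when there is
    no condition, when a callable condition approves the trainer state, or while the
    forward-call count is still below an integer threshold."""
    # callable() narrows the Union for mypy, unlike isinstance(x, Callable)
    is_func_true = callable(trigger_condition) and trigger_condition(trainer)
    is_count_true = isinstance(trigger_condition, int) and forward_calls < trigger_condition
    return trigger_condition is None or is_func_true or is_count_true


assert should_quantize(None, forward_calls=0, trainer=None)    # no condition: always run
assert should_quantize(10, forward_calls=3, trainer=None)      # count-based: 3 < 10
assert not should_quantize(lambda t: False, forward_calls=0, trainer=None)  # callable veto
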
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     compatibility."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        data = model.quant(data)
+    def wrapper(data: Any) -> Any:
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
@@ -181,7 +184,9 @@ class QuantizationAwareTraining(Callback):
             )
         self._observer_type = observer_type

-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) or callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )
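
For reference, a usage sketch of the two accepted forms of `collect_quantization`; this is hedged, inferred from the wrapper logic earlier in the diff (an int bounds the counted forward calls, a callable receives `model.trainer`) rather than taken from the callback's documentation:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining

# int form: keep (de)quantizing only while fewer than 100 wrapped forward calls have run
qat_by_count = QuantizationAwareTraining(collect_quantization=100)

# callable form: the condition is called with the trainer, so trainer state can drive it
qat_by_epoch = QuantizationAwareTraining(collect_quantization=lambda trainer: trainer.current_epoch < 3)

trainer = Trainer(callbacks=[qat_by_count])  # or [qat_by_epoch]
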
@@ -200,8 +205,8 @@
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
@@ -227,7 +232,7 @@
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)

-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
@@ -237,7 +242,7 @@
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )

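
The ignore added on the `model.forward` assignment (and the matching ones later in this file) covers monkeypatching a method on an instance, which is valid at runtime but rejected by mypy. A self-contained toy illustration, not Lightning code:

import functools
from typing import Any, Callable


class ToyModule:
    def forward(self, data: Any) -> Any:
        return data


def wrap(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
    @functools.wraps(func)
    def wrapper(data: Any) -> Any:
        # a real wrapper would quantize here, call func, then dequantize
        return func(data)

    return wrapper


module = ToyModule()
# mypy flags assigning to a method attribute (the `assignment` error code at the
# time of this PR), so the rebinding needs the same kind of ignore as the diff.
module.forward = wrap(module.forward)  # type: ignore[assignment]
assert module.forward(42) == 42
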
@@ -247,7 +252,7 @@
            if self._observer_type == "histogram":
                model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
            elif self._observer_type == "average":
-                extra_kwargs = {}
+                extra_kwargs: Dict[str, Optional[int]] = {}
                if _TORCH_GREATER_EQUAL_1_12:
                    extra_kwargs["version"] = 0
                # version=None corresponds to using FakeQuantize rather than
@@ -258,7 +263,7 @@
                model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

        elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig
+            model.qconfig = self._qconfig  # type: ignore [assignment]

        if self._check_feasible_fuse(model):
            fuse_modules(model, self._modules_to_fuse, inplace=True)
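
A brief note on the new `Dict[str, Optional[int]]` annotation, offered as the likely reasoning rather than anything stated in the PR: strict mypy asks for an explicit type on a bare `{}` literal, and `Optional[int]` covers both the `version=0` assignment above and the `version=None` behaviour described in the preceding comment:

from typing import Dict, Optional

extra_kwargs: Dict[str, Optional[int]] = {}  # a bare `{}` would trigger mypy's "need type annotation"
extra_kwargs["version"] = 0     # the torch >= 1.12 branch shown in the diff
extra_kwargs["version"] = None  # also type-checks, matching the Optional value type
print(extra_kwargs)
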
@@ -273,12 +278,12 @@
         }
         self._module_prepared = True

-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)

     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
+                model=pl_module,
+                func=self.__module_forward,
+            )
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
@@ -336,7 +344,7 @@
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}

-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -245,7 +245,7 @@ class CheckpointConnector:
            if state:
                # The Quantization callbacks have a special method that must be called before restoring the weights
                # of the model
-                callback._load_before_model(self.trainer.model, deepcopy(state))
+                callback._load_before_model(self.trainer.lightning_module, deepcopy(state))

    def restore_callbacks(self) -> None:
        """Restores all callbacks from the pre-loaded checkpoint."""