Fix mypy errors attributed to `pytorch_lightning.callbacks.quantization` (#13782)
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
0b1a29b5eb
commit
0ca3b5aa1b
|
@ -50,7 +50,6 @@ warn_no_return = "False"
|
|||
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
|
||||
module = [
|
||||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.callbacks.quantization",
|
||||
"pytorch_lightning.core.datamodule",
|
||||
"pytorch_lightning.demos.boring_classes",
|
||||
"pytorch_lightning.demos.mnist_datamodule",
|
||||
|
|
|
@ -41,25 +41,28 @@ else:
|
|||
|
||||
|
||||
def wrap_qat_forward_context(
|
||||
quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
|
||||
quant_cb: "QuantizationAwareTraining",
|
||||
model: "pl.LightningModule",
|
||||
func: Callable,
|
||||
trigger_condition: Optional[Union[Callable, int]] = None,
|
||||
) -> Callable:
|
||||
"""Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
|
||||
compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
|
||||
training all the time."""
|
||||
# todo: consider using registering hook before/after forward
|
||||
@functools.wraps(func)
|
||||
def wrapper(data) -> Any:
|
||||
_is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
|
||||
def wrapper(data: Any) -> Any:
|
||||
_is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
|
||||
_is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
|
||||
_quant_run = trigger_condition is None or _is_func_true or _is_count_true
|
||||
# apply custom trigger
|
||||
if _quant_run:
|
||||
quant_cb._forward_calls += 1
|
||||
data = model.quant(data)
|
||||
data = model.quant(data) # type: ignore[operator]
|
||||
data = func(data)
|
||||
# apply custom trigger
|
||||
if _quant_run:
|
||||
data = model.dequant(data)
|
||||
data = model.dequant(data) # type: ignore[operator]
|
||||
return data
|
||||
|
||||
return wrapper
|
||||
|
@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
|
|||
compatibility."""
|
||||
# todo: consider using registering hook before/after forward
|
||||
@functools.wraps(func)
|
||||
def wrapper(data) -> Any:
|
||||
data = model.quant(data)
|
||||
def wrapper(data: Any) -> Any:
|
||||
data = model.quant(data) # type: ignore[operator]
|
||||
data = func(data)
|
||||
data = model.dequant(data)
|
||||
data = model.dequant(data) # type: ignore[operator]
|
||||
return data
|
||||
|
||||
return wrapper
|
||||
|
@ -181,7 +184,9 @@ class QuantizationAwareTraining(Callback):
|
|||
)
|
||||
self._observer_type = observer_type
|
||||
|
||||
if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
|
||||
if collect_quantization is not None and not (
|
||||
isinstance(collect_quantization, int) or callable(collect_quantization)
|
||||
):
|
||||
raise MisconfigurationException(
|
||||
f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
|
||||
)
|
||||
|
@ -200,8 +205,8 @@ class QuantizationAwareTraining(Callback):
|
|||
self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
|
||||
|
||||
self._forward_calls = 0
|
||||
self._fake_quant_to_initial_state_dict = {}
|
||||
self._last_fake_quant_to_observer_enabled = {}
|
||||
self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
|
||||
self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
|
||||
self._module_prepared = False
|
||||
|
||||
def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
|
||||
|
@ -227,7 +232,7 @@ class QuantizationAwareTraining(Callback):
|
|||
for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
|
||||
fake_quant.observer_enabled.copy_(observer_enabled)
|
||||
|
||||
def _prepare_model(self, model: torch.nn.Module) -> None:
|
||||
def _prepare_model(self, model: "pl.LightningModule") -> None:
|
||||
if self._module_prepared:
|
||||
return
|
||||
# QuantStub converts tensors from floating point to quantized
|
||||
|
@ -237,7 +242,7 @@ class QuantizationAwareTraining(Callback):
|
|||
# manually specify where tensors will be converted from quantized
|
||||
# to floating point in the quantized model
|
||||
self.__module_forward = model.forward
|
||||
model.forward = wrap_qat_forward_context(
|
||||
model.forward = wrap_qat_forward_context( # type: ignore [assignment]
|
||||
quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
|
||||
)
|
||||
|
||||
|
@ -247,7 +252,7 @@ class QuantizationAwareTraining(Callback):
|
|||
if self._observer_type == "histogram":
|
||||
model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
|
||||
elif self._observer_type == "average":
|
||||
extra_kwargs = {}
|
||||
extra_kwargs: Dict[str, Optional[int]] = {}
|
||||
if _TORCH_GREATER_EQUAL_1_12:
|
||||
extra_kwargs["version"] = 0
|
||||
# version=None corresponds to using FakeQuantize rather than
|
||||
|
@ -258,7 +263,7 @@ class QuantizationAwareTraining(Callback):
|
|||
model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
|
||||
|
||||
elif isinstance(self._qconfig, QConfig):
|
||||
model.qconfig = self._qconfig
|
||||
model.qconfig = self._qconfig # type: ignore [assignment]
|
||||
|
||||
if self._check_feasible_fuse(model):
|
||||
fuse_modules(model, self._modules_to_fuse, inplace=True)
|
||||
|
@ -273,12 +278,12 @@ class QuantizationAwareTraining(Callback):
|
|||
}
|
||||
self._module_prepared = True
|
||||
|
||||
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
|
||||
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self._prepare_model(pl_module)
|
||||
|
||||
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if not self._convert_on_fit_end:
|
||||
pl_module.forward = self.__module_forward
|
||||
pl_module.forward = self.__module_forward # type: ignore [assignment]
|
||||
return
|
||||
pl_module.eval()
|
||||
# Convert the observed model to a quantized model. This does several things:
|
||||
|
@ -288,9 +293,12 @@ class QuantizationAwareTraining(Callback):
|
|||
torch.quantization.convert(pl_module, inplace=True)
|
||||
# check we shall preserve wrapper
|
||||
if self._input_compatible:
|
||||
pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
|
||||
pl_module.forward = wrap_quantize_forward_context( # type: ignore [assignment]
|
||||
model=pl_module,
|
||||
func=self.__module_forward,
|
||||
)
|
||||
else:
|
||||
pl_module.forward = self.__module_forward
|
||||
pl_module.forward = self.__module_forward # type: ignore [assignment]
|
||||
|
||||
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if "train" in self._observer_disabled_stages:
|
||||
|
@ -336,7 +344,7 @@ class QuantizationAwareTraining(Callback):
|
|||
keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
|
||||
return {n: getattr(self, n) for n in keys}
|
||||
|
||||
def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
|
||||
def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
|
||||
"""Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
|
||||
|
||||
This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
|
||||
|
|
|
@ -245,7 +245,7 @@ class CheckpointConnector:
|
|||
if state:
|
||||
# The Quantization callbacks have a special method that must be called before restoring the weights
|
||||
# of the model
|
||||
callback._load_before_model(self.trainer.model, deepcopy(state))
|
||||
callback._load_before_model(self.trainer.lightning_module, deepcopy(state))
|
||||
|
||||
def restore_callbacks(self) -> None:
|
||||
"""Restores all callbacks from the pre-loaded checkpoint."""
|
||||
|
|
Loading…
Reference in New Issue