Fix mypy errors attributed to `pytorch_lightning.callbacks.quantization` (#13782)

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Krishna Kalyan 2022-08-20 00:39:16 +01:00 committed by GitHub
parent 0b1a29b5eb
commit 0ca3b5aa1b
3 changed files with 29 additions and 22 deletions

pyproject.toml

@@ -50,7 +50,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",

src/pytorch_lightning/callbacks/quantization.py

@@ -41,25 +41,28 @@ else:
 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: "QuantizationAwareTraining",
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+    def wrapper(data: Any) -> Any:
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
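The switch from `isinstance(trigger_condition, Callable)` to the `callable()` builtin is the idiomatic runtime check, and it is one mypy narrows on, so the subsequent `trigger_condition(model.trainer)` call type-checks without an ignore. A minimal sketch of that narrowing, with hypothetical names:

```python
# Minimal sketch (hypothetical names) of the narrowing used in the wrapper:
# inside `if callable(trigger):` mypy treats `trigger` as the callable member
# of the union, so calling it is accepted.
from typing import Callable, Optional, Union


def should_quantize(trigger: Optional[Union[Callable[[int], bool], int]], step: int) -> bool:
    if callable(trigger):
        return trigger(step)  # narrowed to Callable[[int], bool]
    if isinstance(trigger, int):
        return step < trigger  # narrowed to int
    return True  # trigger is None -> always quantize


print(should_quantize(lambda s: s % 2 == 0, step=4))  # True
print(should_quantize(3, step=5))                     # False
print(should_quantize(None, step=0))                  # True
```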
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable:
     compatibility."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        data = model.quant(data)
+    def wrapper(data: Any) -> Any:
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
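Both wrappers call `model.quant(...)` and `model.dequant(...)` under `# type: ignore[operator]`. A plausible reason (an assumption, not stated in the diff) is that these attributes are not declared on `pl.LightningModule`, so mypy resolves them through `nn.Module.__getattr__`, which the torch stubs type as returning `Union[Tensor, Module]`; a `Tensor` is not callable, hence the `[operator]` error. An alternative sketch using `typing.cast` instead of the targeted ignore:

```python
# Sketch (assumption about the error's cause; not the approach taken in the
# diff): cast the dynamically resolved attribute to nn.Module so mypy knows
# it is callable.
from typing import cast

import torch
from torch import nn


def quantize_input(model: nn.Module, data: torch.Tensor) -> torch.Tensor:
    quant = cast(nn.Module, model.quant)  # mypy otherwise sees Union[Tensor, Module]
    return quant(data)


model = nn.Module()
model.quant = torch.quantization.QuantStub()  # identity until the model is converted
print(quantize_input(model, torch.randn(2, 3)).shape)  # torch.Size([2, 3])
```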
@@ -181,7 +184,9 @@ class QuantizationAwareTraining(Callback):
             )
         self._observer_type = observer_type

-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) or callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )
@@ -200,8 +205,8 @@
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
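With the module now strictly checked, the empty dict literals in `__init__` also need explicit annotations; otherwise mypy reports `Need type annotation` (`var-annotated`) because `{}` gives it nothing to infer from. A small sketch of the pattern (the `FakeQuantizeBase` import path below is an assumption for a standalone example; the diff relies on names already imported in the file):

```python
# Sketch of annotating empty containers so strict mypy can type them.
# The FakeQuantizeBase import path is an assumption for a standalone example.
from typing import Any, Dict

from torch import Tensor
from torch.ao.quantization import FakeQuantizeBase


class _State:  # hypothetical stand-in for the callback's bookkeeping
    def __init__(self) -> None:
        # without annotations, `{}` would trigger mypy's [var-annotated] error
        self.fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
        self.last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}


state = _State()
print(state.fake_quant_to_initial_state_dict, state.last_fake_quant_to_observer_enabled)
```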
@@ -227,7 +232,7 @@
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)

-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
@@ -237,7 +242,7 @@
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )
@@ -247,7 +252,7 @@
         if self._observer_type == "histogram":
             model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
         elif self._observer_type == "average":
-            extra_kwargs = {}
+            extra_kwargs: Dict[str, Optional[int]] = {}
             if _TORCH_GREATER_EQUAL_1_12:
                 extra_kwargs["version"] = 0
             # version=None corresponds to using FakeQuantize rather than
@@ -258,7 +263,7 @@
             model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
         elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig
+            model.qconfig = self._qconfig  # type: ignore [assignment]

         if self._check_feasible_fuse(model):
             fuse_modules(model, self._modules_to_fuse, inplace=True)
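`extra_kwargs` gets an explicit `Dict[str, Optional[int]]` value type: an empty `{}` cannot be inferred under strict settings, and the value type presumably has to admit both the `version = 0` set here and the `version=None` behaviour the neighbouring comment describes. A hedged sketch of the same pattern with a hypothetical flag:

```python
# Sketch (hypothetical flag) of why the value type is Optional[int]: the same
# dict may carry version=0 on newer torch, or version=None for the fallback
# behaviour mentioned in the diff's comment.
from typing import Dict, Optional

_TORCH_GREATER_EQUAL_1_12 = True  # hypothetical stand-in for the real version check

extra_kwargs: Dict[str, Optional[int]] = {}
if _TORCH_GREATER_EQUAL_1_12:
    extra_kwargs["version"] = 0  # int fits Optional[int]
else:
    extra_kwargs["version"] = None  # and so does None

print(extra_kwargs)  # {'version': 0}
```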
@@ -273,12 +278,12 @@
         }
         self._module_prepared = True

-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)

     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
+                model=pl_module,
+                func=self.__module_forward,
+            )
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
@@ -336,7 +344,7 @@
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}

-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -245,7 +245,7 @@ class CheckpointConnector:
             if state:
                 # The Quantization callbacks have a special method that must be called before restoring the weights
                 # of the model
-                callback._load_before_model(self.trainer.model, deepcopy(state))
+                callback._load_before_model(self.trainer.lightning_module, deepcopy(state))

     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""