Fix mypy errors attributed to `pytorch_lightning.callbacks.quantization` (#13782)
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

parent 0b1a29b5eb
commit 0ca3b5aa1b
@@ -50,7 +50,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
@@ -41,25 +41,28 @@ else:


 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: "QuantizationAwareTraining",
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+    def wrapper(data: Any) -> Any:
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
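The substantive change in this hunk, beyond the multi-line signature, is swapping isinstance(trigger_condition, Callable) for callable(trigger_condition): mypy treats callable() as a type guard and narrows the Optional[Union[Callable, int]] union accordingly, whereas isinstance against typing.Callable does not narrow cleanly. A minimal standalone sketch of the same narrowing pattern (the names here are illustrative, not taken from the diff):

from typing import Callable, Optional, Union


def should_run(trigger: Optional[Union[Callable[[int], bool], int]], calls: int) -> bool:
    """Mirror of the trigger logic: None means 'always', an int is a call budget,
    and a callable is a user-supplied predicate."""
    is_func_true = callable(trigger) and trigger(calls)  # mypy narrows to the Callable branch
    is_count_true = isinstance(trigger, int) and calls < trigger  # mypy narrows to int here
    return trigger is None or is_func_true or is_count_true


print(should_run(None, 3))                  # True: no condition, always run
print(should_run(5, 3))                     # True: 3 < 5
print(should_run(lambda n: n % 2 == 0, 3))  # False: predicate rejects odd call counts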
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     compatibility."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        data = model.quant(data)
+    def wrapper(data: Any) -> Any:
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
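The repeated # type: ignore[operator] comments deserve a note: quant and dequant are attached to the module dynamically, so the attribute lookup goes through nn.Module.__getattr__, which recent PyTorch versions annotate as returning Union[Tensor, Module]. A Tensor is not callable, so mypy flags the call with error code [operator]; the comment suppresses only that diagnostic on that line. A reduced, hypothetical version of the situation, assuming the stubs were attached at runtime as the callback does:

from typing import Any

import torch


def quantized_call(model: torch.nn.Module, data: Any) -> Any:
    # `model.quant` resolves via nn.Module.__getattr__ (typed Union[Tensor, Module]);
    # calling a Tensor is rejected by mypy with [operator], hence the targeted ignore.
    data = model.quant(data)  # type: ignore[operator]
    data = model.dequant(data)  # type: ignore[operator]
    return data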
@@ -181,7 +184,9 @@ class QuantizationAwareTraining(Callback):
             )
         self._observer_type = observer_type

-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) or callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )
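The rewritten validation applies the same idea as the wrapper above: mypy handles a plain callable() check more gracefully than isinstance against typing.Callable, so the tuple form is split into isinstance(..., int) or callable(...). A small sketch of the pattern with illustrative names (not the callback's actual API):

from typing import Callable, Optional, Union


def validate_collect(collect: Optional[Union[int, Callable]]) -> None:
    # Checking `callable(...)` instead of `isinstance(..., Callable)` keeps both the
    # runtime check and the type checker satisfied without passing a typing construct
    # to isinstance().
    if collect is not None and not (isinstance(collect, int) or callable(collect)):
        raise TypeError(f"Unsupported value {collect!r}; expected an int or a callable.")


validate_collect(10)               # ok
validate_collect(lambda step: True)  # ok
# validate_collect("never")        # would raise TypeError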
@@ -200,8 +205,8 @@ class QuantizationAwareTraining(Callback):
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
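Annotating the empty dicts at their assignment is the standard fix for mypy's "need type annotation" error on bare {} literals, and it also documents what the containers hold (FakeQuantizeBase keys, per the new annotations). A generic sketch of the pattern, using plain nn.Module keys to stay version-agnostic:

from typing import Any, Dict

import torch

# A bare `states = {}` draws a "Need type annotation" error from mypy, and later
# inserts would not be checked against any key/value types. Annotating at the
# assignment fixes both.
fake_quant_states: Dict[torch.nn.Module, Dict[str, Any]] = {}
observer_flags: Dict[torch.nn.Module, torch.Tensor] = {}

# Values added later are now checked against the declared types.
layer = torch.nn.Linear(2, 2)
fake_quant_states[layer] = {"weight": layer.state_dict()["weight"]}
observer_flags[layer] = torch.tensor([1])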
@@ -227,7 +232,7 @@ class QuantizationAwareTraining(Callback):
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)

-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
@@ -237,7 +242,7 @@ class QuantizationAwareTraining(Callback):
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )

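The # type: ignore [assignment] on model.forward = ... addresses a different error than the [operator] ones: forward is declared as a method, and mypy refuses to assign an arbitrary callable over a method (reported under [assignment] in the mypy version used here; newer releases use [method-assign]). A self-contained sketch of the same monkey-patching pattern; the class and wrapper below are hypothetical, not part of the library:

import functools

import torch


class Net(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def wrap_forward(model: Net) -> None:
    original = model.forward

    @functools.wraps(original)
    def wrapper(x: torch.Tensor) -> torch.Tensor:
        print("forward called")
        return original(x)

    # `forward` is declared as a method on Net, so mypy flags assigning a new callable
    # to it; the targeted ignore silences only that error code.
    model.forward = wrapper  # type: ignore[assignment]


net = Net()
wrap_forward(net)
out = net(torch.randn(1, 4))  # prints "forward called", returns a (1, 2) tensor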
@@ -247,7 +252,7 @@ class QuantizationAwareTraining(Callback):
         if self._observer_type == "histogram":
             model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
         elif self._observer_type == "average":
-            extra_kwargs = {}
+            extra_kwargs: Dict[str, Optional[int]] = {}
             if _TORCH_GREATER_EQUAL_1_12:
                 extra_kwargs["version"] = 0
             # version=None corresponds to using FakeQuantize rather than
@@ -258,7 +263,7 @@ class QuantizationAwareTraining(Callback):
             model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

         elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig
+            model.qconfig = self._qconfig  # type: ignore [assignment]

         if self._check_feasible_fuse(model):
             fuse_modules(model, self._modules_to_fuse, inplace=True)
@@ -273,12 +278,12 @@ class QuantizationAwareTraining(Callback):
         }
         self._module_prepared = True

-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)

     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@ class QuantizationAwareTraining(Callback):
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
+                model=pl_module,
+                func=self.__module_forward,
+            )
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
@@ -336,7 +344,7 @@ class QuantizationAwareTraining(Callback):
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}

-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
@@ -245,7 +245,7 @@ class CheckpointConnector:
             if state:
                 # The Quantization callbacks have a special method that must be called before restoring the weights
                 # of the model
-                callback._load_before_model(self.trainer.model, deepcopy(state))
+                callback._load_before_model(self.trainer.lightning_module, deepcopy(state))

     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
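The connector-side change follows from the signature tightening above: _load_before_model now expects a pl.LightningModule, and trainer.model is annotated as a plain torch.nn.Module (it can be the strategy-wrapped module, e.g. under DDP), so the connector passes trainer.lightning_module instead. A reduced sketch of the relationship, with hypothetical helper names standing in for the real callback and connector methods:

from typing import Any, Dict

import pytorch_lightning as pl


def load_before_model(model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
    """Stand-in for the callback hook with its newly narrowed parameter type."""


def restore_quantization_state(trainer: "pl.Trainer", state: Dict[str, Any]) -> None:
    # `trainer.lightning_module` is annotated as the user's LightningModule, so it
    # satisfies the hook's signature; `trainer.model` is the (possibly wrapped)
    # torch.nn.Module and would no longer type-check here.
    load_before_model(trainer.lightning_module, state)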