diff --git a/pyproject.toml b/pyproject.toml index 45f65b4c44..14e49ffb24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ warn_no_return = "False" # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",' module = [ "pytorch_lightning.callbacks.progress.rich_progress", - "pytorch_lightning.callbacks.quantization", "pytorch_lightning.core.datamodule", "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py index af983ef101..d89bed0394 100644 --- a/src/pytorch_lightning/callbacks/quantization.py +++ b/src/pytorch_lightning/callbacks/quantization.py @@ -41,25 +41,28 @@ else: def wrap_qat_forward_context( - quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None + quant_cb: "QuantizationAwareTraining", + model: "pl.LightningModule", + func: Callable, + trigger_condition: Optional[Union[Callable, int]] = None, ) -> Callable: """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the training all the time.""" # todo: consider using registering hook before/after forward @functools.wraps(func) - def wrapper(data) -> Any: - _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer) + def wrapper(data: Any) -> Any: + _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer) _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition _quant_run = trigger_condition is None or _is_func_true or _is_count_true # apply custom trigger if _quant_run: quant_cb._forward_calls += 1 - data = model.quant(data) + data = model.quant(data) # type: ignore[operator] data = func(data) # apply custom trigger if _quant_run: - data = model.dequant(data) + data = model.dequant(data) # type: ignore[operator] return data return wrapper @@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) - compatibility.""" # todo: consider using registering hook before/after forward @functools.wraps(func) - def wrapper(data) -> Any: - data = model.quant(data) + def wrapper(data: Any) -> Any: + data = model.quant(data) # type: ignore[operator] data = func(data) - data = model.dequant(data) + data = model.dequant(data) # type: ignore[operator] return data return wrapper @@ -181,7 +184,9 @@ class QuantizationAwareTraining(Callback): ) self._observer_type = observer_type - if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)): + if collect_quantization is not None and not ( + isinstance(collect_quantization, int) or callable(collect_quantization) + ): raise MisconfigurationException( f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.' ) @@ -200,8 +205,8 @@ class QuantizationAwareTraining(Callback): self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages self._forward_calls = 0 - self._fake_quant_to_initial_state_dict = {} - self._last_fake_quant_to_observer_enabled = {} + self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {} + self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {} self._module_prepared = False def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool: @@ -227,7 +232,7 @@ class QuantizationAwareTraining(Callback): for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items(): fake_quant.observer_enabled.copy_(observer_enabled) - def _prepare_model(self, model: torch.nn.Module) -> None: + def _prepare_model(self, model: "pl.LightningModule") -> None: if self._module_prepared: return # QuantStub converts tensors from floating point to quantized @@ -237,7 +242,7 @@ class QuantizationAwareTraining(Callback): # manually specify where tensors will be converted from quantized # to floating point in the quantized model self.__module_forward = model.forward - model.forward = wrap_qat_forward_context( + model.forward = wrap_qat_forward_context( # type: ignore [assignment] quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization ) @@ -247,7 +252,7 @@ class QuantizationAwareTraining(Callback): if self._observer_type == "histogram": model.qconfig = torch.quantization.get_default_qconfig(self._qconfig) elif self._observer_type == "average": - extra_kwargs = {} + extra_kwargs: Dict[str, Optional[int]] = {} if _TORCH_GREATER_EQUAL_1_12: extra_kwargs["version"] = 0 # version=None corresponds to using FakeQuantize rather than @@ -258,7 +263,7 @@ class QuantizationAwareTraining(Callback): model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs) elif isinstance(self._qconfig, QConfig): - model.qconfig = self._qconfig + model.qconfig = self._qconfig # type: ignore [assignment] if self._check_feasible_fuse(model): fuse_modules(model, self._modules_to_fuse, inplace=True) @@ -273,12 +278,12 @@ class QuantizationAwareTraining(Callback): } self._module_prepared = True - def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._prepare_model(pl_module) def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self._convert_on_fit_end: - pl_module.forward = self.__module_forward + pl_module.forward = self.__module_forward # type: ignore [assignment] return pl_module.eval() # Convert the observed model to a quantized model. This does several things: @@ -288,9 +293,12 @@ class QuantizationAwareTraining(Callback): torch.quantization.convert(pl_module, inplace=True) # check we shall preserve wrapper if self._input_compatible: - pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward) + pl_module.forward = wrap_quantize_forward_context( # type: ignore [assignment] + model=pl_module, + func=self.__module_forward, + ) else: - pl_module.forward = self.__module_forward + pl_module.forward = self.__module_forward # type: ignore [assignment] def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if "train" in self._observer_disabled_stages: @@ -336,7 +344,7 @@ class QuantizationAwareTraining(Callback): keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"} return {n: getattr(self, n) for n in keys} - def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None: + def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None: """Special hook that gets called by the CheckpointConnector *before* the model gets loaded. This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 22f61c8453..e1dccd11a0 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -245,7 +245,7 @@ class CheckpointConnector: if state: # The Quantization callbacks have a special method that must be called before restoring the weights # of the model - callback._load_before_model(self.trainer.model, deepcopy(state)) + callback._load_before_model(self.trainer.lightning_module, deepcopy(state)) def restore_callbacks(self) -> None: """Restores all callbacks from the pre-loaded checkpoint."""