From dbea5bb710a02f9a62f6f06accdd50796c52e65f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 19 May 2021 22:01:42 +0200
Subject: [PATCH] Add typing to `ModelPruning` callback (#7529)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 pytorch_lightning/callbacks/pruning.py | 93 +++++++++++++++-----------
 requirements.txt                       |  1 +
 setup.cfg                              |  3 +
 3 files changed, 58 insertions(+), 39 deletions(-)

diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index 715fa14a41..b044fc3e6e 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -19,12 +19,14 @@ import inspect
 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
+from typing_extensions import TypedDict
 
+import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
@@ -47,8 +49,9 @@ _PYTORCH_PRUNING_METHOD = {
 }
 
 _PARAM_TUPLE = Tuple[nn.Module, str]
-_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]]
+_PARAM_LIST = Sequence[_PARAM_TUPLE]
 _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)
+_LayerRef = TypedDict('_LayerRef', {'data': nn.Module, 'names': List[Tuple[int, str]]})
 
 
 class ModelPruning(Callback):
@@ -57,7 +60,7 @@ class ModelPruning(Callback):
     def __init__(
         self,
         pruning_fn: Union[Callable, str],
-        parameters_to_prune: Optional[_PARAM_LIST] = None,
+        parameters_to_prune: _PARAM_LIST = (),
         parameter_names: Optional[List[str]] = None,
         use_global_unstructured: bool = True,
         amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
@@ -153,9 +156,9 @@ class ModelPruning(Callback):
         self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
         self._resample_parameters = resample_parameters
         self._parameter_names = parameter_names or self.PARAMETER_NAMES
-        self._global_kwargs = {}
-        self._original_layers = None
-        self._pruning_fn_name = None
+        self._global_kwargs: Dict[str, Any] = {}
+        self._original_layers: Optional[Dict[int, _LayerRef]] = None
+        self._pruning_fn_name: Optional[str] = None
 
         for name in self._parameter_names:
             if name not in self.PARAMETER_NAMES:
@@ -196,9 +199,10 @@ class ModelPruning(Callback):
                 " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
             )
 
-        if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
+        # need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
+        if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":  # type: ignore
             raise MisconfigurationException(
                 'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
-                f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
+                f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "  # type: ignore
             )
 
" ) @@ -206,7 +210,7 @@ class ModelPruning(Callback): self._apply_pruning = apply_pruning self._make_pruning_permanent = make_pruning_permanent - if not isinstance(amount, (int, float, Callable)): + if not (isinstance(amount, (int, float)) or callable(amount)): raise MisconfigurationException( "`amount` should be provided and be either an int, a float or Callable function." ) @@ -218,13 +222,13 @@ class ModelPruning(Callback): self._verbose = verbose - def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: + def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST: """ This function can be overridden to control which module to prune. """ return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]: + def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: """ This function takes `pruning_fn`, a function name. @@ -232,11 +236,13 @@ class ModelPruning(Callback): ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. """ + pruning_fn = ( + _PYTORCH_PRUNING_METHOD[pruning_fn] + if self._use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] + ) + assert callable(pruning_fn) if self._use_global_unstructured: - pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn] self._global_kwargs = kwargs - else: - pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn] # save the function __name__ now because partial does not include it # and there are issues setting the attribute manually in ddp. self._pruning_fn_name = pruning_fn.__name__ @@ -245,10 +251,10 @@ class ModelPruning(Callback): return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs) @staticmethod - def _wrap_pruning_fn(pruning_fn, **kwargs): + def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable: return partial(pruning_fn, **kwargs) - def make_pruning_permanent(self, pl_module: LightningModule): + def make_pruning_permanent(self, pl_module: LightningModule) -> None: """ Removes pruning buffers from any pruned modules @@ -261,14 +267,14 @@ class ModelPruning(Callback): hook.remove(module) del module._forward_pre_hooks[k] - def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): + def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str) -> None: trained = getattr(module, tensor_name) orig = getattr(orig_module, tensor_name) if trained is None or orig is None: return trained.data = orig.data.to(trained.device) - def apply_lottery_ticket_hypothesis(self): + def apply_lottery_ticket_hypothesis(self) -> None: r""" Lottery ticket hypothesis algorithm (see page 2 of the paper): @@ -282,33 +288,35 @@ class ModelPruning(Callback): The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` """ # noqa: E501 - def copy_param(new, old, name: str) -> None: + def copy_param(new: nn.Module, old: nn.Module, name: str) -> None: dst = getattr(new, name) src = getattr(old, name) if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor): return dst.data = src.data.to(dst.device) + assert self._original_layers is not None for d in self._original_layers.values(): - copy, names = d["data"], d["names"] - if self._resample_parameters and hasattr(copy, "reset_parameters"): + copy = d["data"] + 
+            names = d["names"]
+            if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
                 copy = deepcopy(copy)  # keep the original parameters
                 copy.reset_parameters()
             for i, name in names:
                 new, new_name = self._parameters_to_prune[i]
                 copy_param(new, copy, name)
 
-    def _apply_local_pruning(self, amount: float):
+    def _apply_local_pruning(self, amount: float) -> None:
         for module, name in self._parameters_to_prune:
             self.pruning_fn(module, name=name, amount=amount)
 
-    def _resolve_global_kwargs(self, amount: float):
+    def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
         self._global_kwargs["amount"] = amount
         params = set(inspect.signature(self.pruning_fn).parameters)
         params.discard("self")
         return {k: v for k, v in self._global_kwargs.items() if k in params}
 
-    def _apply_global_pruning(self, amount: float):
+    def _apply_global_pruning(self, amount: float) -> None:
         pytorch_prune.global_unstructured(
             self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
         )
@@ -321,7 +329,7 @@ class ModelPruning(Callback):
         mask = getattr(module, attr)
         return (mask == 0).sum().item(), mask.numel()
 
-    def apply_pruning(self, amount: Union[int, float]):
+    def apply_pruning(self, amount: Union[int, float]) -> None:
         """ Applies pruning to ``parameters_to_prune``. """
         if self._verbose:
             prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
@@ -338,7 +346,7 @@ class ModelPruning(Callback):
     @rank_zero_only
     def _log_sparsity_stats(
         self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
-    ):
+    ) -> None:
         total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
@@ -357,7 +365,7 @@ class ModelPruning(Callback):
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )
 
-    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
+    def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -370,29 +378,34 @@ class ModelPruning(Callback):
             self._original_layers = {}
             for i, (module, name) in enumerate(self._parameters_to_prune):
                 id_ = id(module)
-                self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
+                self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
                 self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module: LightningModule):
-        current_epoch = trainer.current_epoch
-        prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
-        amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
+    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:  # type: ignore
+        current_epoch = pl_module.current_epoch
+        prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
+        amount = self.amount(current_epoch) if callable(self.amount) else self.amount
         if not prune or not amount:
             return
         self.apply_pruning(amount)
 
         if (
             self._use_lottery_ticket_hypothesis(current_epoch)
-            if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis
+            if callable(self._use_lottery_ticket_hypothesis) else self._use_lottery_ticket_hypothesis
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, trainer, pl_module: LightningModule):
+    def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
             self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
+    def on_save_checkpoint(
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: LightningModule,
+        checkpoint: Dict[str, Any],
+    ) -> Dict[str, Any]:
         if self._make_pruning_permanent:
             rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
             prev_device = pl_module.device
@@ -402,11 +415,13 @@ class ModelPruning(Callback):
             checkpoint["state_dict"] = copy.state_dict()
             pl_module.to(prev_device)
 
+        return checkpoint
+
     @staticmethod
     def sanitize_parameters_to_prune(
         pl_module: LightningModule,
-        parameters_to_prune: Optional[_PARAM_LIST] = None,
-        parameter_names: Optional[List[str]] = None,
+        parameters_to_prune: _PARAM_LIST = (),
+        parameter_names: Sequence[str] = (),
     ) -> _PARAM_LIST:
         """
         This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``.
@@ -415,13 +430,13 @@
 
         Raises:
             MisconfigurationException:
                 If ``parameters_to_prune`` doesn't exist in the model, or
-                if ``parameters_to_prune`` is neither a list of tuple nor ``None``.
+                if ``parameters_to_prune`` is neither a list nor a tuple.
         """
         parameters = parameter_names or ModelPruning.PARAMETER_NAMES
         current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
 
-        if parameters_to_prune is None:
+        if not parameters_to_prune:
             parameters_to_prune = [(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None]
         elif (
diff --git a/requirements.txt b/requirements.txt
index c3a4caaf64..964bb493a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ tensorboard>=2.2.0, !=2.5.0  # 2.5.0 GPU CI error: 'Couldn't build proto file in
 torchmetrics>=0.2.0
 pyDeprecate==0.3.0
 packaging
+typing-extensions  # TypedDict support for python<3.8
diff --git a/setup.cfg b/setup.cfg
index ba4ee69da5..5a68adb27b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -117,6 +117,9 @@ disable_error_code = attr-defined
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.callbacks.*]
 ignore_errors = True
+# whitelist
+[mypy-pytorch_lightning.callbacks.pruning]
+ignore_errors = False
 
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.core.*]
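
For context, a minimal sketch of how the typed callback reads from user code. The schedule
function and the Trainer setup below are illustrative assumptions, not part of this patch;
only the `ModelPruning` parameters themselves (`pruning_fn`, `amount`,
`use_global_unstructured`, and the new empty-sequence default for `parameters_to_prune`)
come from the diff above:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelPruning

    def compute_amount(epoch: int) -> float:
        # Hypothetical schedule matching `amount`'s typed signature
        # Callable[[int], Union[int, float]]: prune 50% at epoch 5, then
        # 25% each later epoch. Returning 0 makes `on_train_epoch_end`
        # skip pruning for that epoch (`if not prune or not amount`).
        if epoch == 5:
            return 0.5
        return 0.25 if epoch > 5 else 0.0

    # Leaving `parameters_to_prune` at its new default `()` prunes every
    # supported parameter ("weight"/"bias") found on the model's modules.
    pruning = ModelPruning(
        pruning_fn="l1_unstructured",  # resolved via _PYTORCH_PRUNING_METHOD when global
        amount=compute_amount,
        use_global_unstructured=True,  # requires an "unstructured" PRUNING_TYPE
    )
    trainer = Trainer(callbacks=[pruning])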