Add typing to `ModelPruning` callback (#7529)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Carlos Mocholí 2021-05-19 22:01:42 +02:00 committed by GitHub
parent 608de6abf4
commit dbea5bb710
3 changed files with 58 additions and 39 deletions

pytorch_lightning/callbacks/pruning.py

@@ -19,12 +19,14 @@ import inspect
import logging
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.utils.prune as pytorch_prune
from torch import nn
from typing_extensions import TypedDict
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
@@ -47,8 +49,9 @@ _PYTORCH_PRUNING_METHOD = {
}
_PARAM_TUPLE = Tuple[nn.Module, str]
_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]]
_PARAM_LIST = Sequence[_PARAM_TUPLE]
_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)
_LayerRef = TypedDict('_LayerRef', {'data': nn.Module, 'names': List[Tuple[int, str]]})
class ModelPruning(Callback):
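Aside: the functional `TypedDict` form above gives the layer-reference dicts a checked shape. A minimal standalone sketch (the example layer and values are illustrative):

```python
from typing import List, Tuple

from torch import nn
from typing_extensions import TypedDict  # backport for Python < 3.8

# same functional form as in the diff above
_LayerRef = TypedDict('_LayerRef', {'data': nn.Module, 'names': List[Tuple[int, str]]})

# mypy checks the key names and value types; at runtime this is a plain dict
ref = _LayerRef(data=nn.Linear(4, 2), names=[(0, "weight")])
ref["names"].append((1, "bias"))  # OK: matches List[Tuple[int, str]]
```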
@@ -57,7 +60,7 @@ class ModelPruning(Callback):
def __init__(
self,
pruning_fn: Union[Callable, str],
parameters_to_prune: Optional[_PARAM_LIST] = None,
parameters_to_prune: _PARAM_LIST = (),
parameter_names: Optional[List[str]] = None,
use_global_unstructured: bool = True,
amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
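Aside: the `Optional[_PARAM_LIST] = None` → `_PARAM_LIST = ()` change works because `Sequence` already covers lists and tuples and an empty tuple is a safe immutable default. A hedged sketch with a hypothetical helper:

```python
from typing import Sequence, Tuple

from torch import nn

_PARAM_TUPLE = Tuple[nn.Module, str]

def count_prune_targets(parameters_to_prune: Sequence[_PARAM_TUPLE] = ()) -> int:
    # `()` is immutable, so it avoids the mutable-default pitfall of `[]`,
    # and an emptiness check replaces the old `is None` sentinel test
    return len(parameters_to_prune)

layer = nn.Linear(4, 2)
assert count_prune_targets([(layer, "weight")]) == 1  # a list is a Sequence
assert count_prune_targets(((layer, "bias"),)) == 1   # so is a tuple
```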
@@ -153,9 +156,9 @@ class ModelPruning(Callback):
self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
self._resample_parameters = resample_parameters
self._parameter_names = parameter_names or self.PARAMETER_NAMES
self._global_kwargs = {}
self._original_layers = None
self._pruning_fn_name = None
self._global_kwargs: Dict[str, Any] = {}
self._original_layers: Optional[Dict[int, _LayerRef]] = None
self._pruning_fn_name: Optional[str] = None
for name in self._parameter_names:
if name not in self.PARAMETER_NAMES:
@@ -196,9 +199,10 @@ class ModelPruning(Callback):
" HINT: if passing a `BasePruningMethod`, pass the class, not an instance"
)
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
# need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore
raise MisconfigurationException(
'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
)
@@ -206,7 +210,7 @@ class ModelPruning(Callback):
self._apply_pruning = apply_pruning
self._make_pruning_permanent = make_pruning_permanent
if not isinstance(amount, (int, float, Callable)):
if not (isinstance(amount, (int, float)) or callable(amount)):
raise MisconfigurationException(
"`amount` should be provided and be either an int, a float or a callable function."
)
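Aside: `isinstance(amount, (int, float, Callable))` happens to work at runtime because `typing.Callable` supports `isinstance`, but mypy rejects it; `callable(amount)` is the idiomatic check. A sketch with a hypothetical validator:

```python
from typing import Callable, Union

def validate_amount(amount: Union[int, float, Callable[[int], float]]) -> None:
    # `callable()` is the canonical runtime test and keeps mypy happy
    if not (isinstance(amount, (int, float)) or callable(amount)):
        raise TypeError("`amount` should be an int, a float, or a callable")

validate_amount(0.5)                # fixed pruning amount
validate_amount(lambda epoch: 0.1)  # per-epoch schedule
```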
@@ -218,13 +222,13 @@ class ModelPruning(Callback):
self._verbose = verbose
def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]:
def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST:
"""
This function can be overridden to control which modules to prune.
"""
return parameters_to_prune
def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]:
def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]:
"""
This function takes `pruning_fn`, a function name.
@@ -232,11 +236,13 @@ class ModelPruning(Callback):
ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.
"""
pruning_fn = (
_PYTORCH_PRUNING_METHOD[pruning_fn]
if self._use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
)
assert callable(pruning_fn)
if self._use_global_unstructured:
pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
self._global_kwargs = kwargs
else:
pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
# save the function __name__ now because partial does not include it
# and there are issues setting the attribute manually in ddp.
self._pruning_fn_name = pruning_fn.__name__
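Aside: the comment about `partial` is easy to verify; a `functools.partial` object does not expose the wrapped function's `__name__`, so the name must be captured before wrapping:

```python
from functools import partial

import torch.nn.utils.prune as pytorch_prune

fn = partial(pytorch_prune.l1_unstructured, amount=0.5)
assert not hasattr(fn, "__name__")  # partial drops the wrapped name

# hence the name is read from the bare function first
assert pytorch_prune.l1_unstructured.__name__ == "l1_unstructured"
```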
@@ -245,10 +251,10 @@ class ModelPruning(Callback):
return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)
@staticmethod
def _wrap_pruning_fn(pruning_fn, **kwargs):
def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable:
return partial(pruning_fn, **kwargs)
def make_pruning_permanent(self, pl_module: LightningModule):
def make_pruning_permanent(self, pl_module: LightningModule) -> None:
"""
Removes pruning buffers from any pruned modules
@@ -261,14 +267,14 @@ class ModelPruning(Callback):
hook.remove(module)
del module._forward_pre_hooks[k]
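Aside: removing those forward pre-hooks is what `torch.nn.utils.prune.remove` does under the hood. A minimal sketch on a standalone layer:

```python
import torch.nn.utils.prune as pytorch_prune
from torch import nn

layer = nn.Linear(4, 2)
pytorch_prune.l1_unstructured(layer, name="weight", amount=0.5)
# pruning registers a mask buffer, a `weight_orig` parameter, and a
# forward pre-hook that recomputes `weight` on every call
assert hasattr(layer, "weight_mask") and hasattr(layer, "weight_orig")

pytorch_prune.remove(layer, "weight")  # folds the mask in permanently
assert not hasattr(layer, "weight_mask")
```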
def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str) -> None:
trained = getattr(module, tensor_name)
orig = getattr(orig_module, tensor_name)
if trained is None or orig is None:
return
trained.data = orig.data.to(trained.device)
def apply_lottery_ticket_hypothesis(self):
def apply_lottery_ticket_hypothesis(self) -> None:
r"""
Lottery ticket hypothesis algorithm (see page 2 of the paper):
@@ -282,33 +288,35 @@ class ModelPruning(Callback):
The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
""" # noqa: E501
def copy_param(new, old, name: str) -> None:
def copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
dst = getattr(new, name)
src = getattr(old, name)
if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
return
dst.data = src.data.to(dst.device)
assert self._original_layers is not None
for d in self._original_layers.values():
copy, names = d["data"], d["names"]
if self._resample_parameters and hasattr(copy, "reset_parameters"):
copy = d["data"]
names = d["names"]
if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
copy = deepcopy(copy) # keep the original parameters
copy.reset_parameters()
for i, name in names:
new, new_name = self._parameters_to_prune[i]
copy_param(new, copy, name)
def _apply_local_pruning(self, amount: float):
def _apply_local_pruning(self, amount: float) -> None:
for module, name in self._parameters_to_prune:
self.pruning_fn(module, name=name, amount=amount)
def _resolve_global_kwargs(self, amount: float):
def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
self._global_kwargs["amount"] = amount
params = set(inspect.signature(self.pruning_fn).parameters)
params.discard("self")
return {k: v for k, v in self._global_kwargs.items() if k in params}
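Aside: `_resolve_global_kwargs` filters the stored kwargs down to those the pruning method's signature actually declares. A standalone sketch with hypothetical names:

```python
import inspect
from typing import Any, Callable, Dict

def resolve_kwargs(fn: Callable, stored: Dict[str, Any]) -> Dict[str, Any]:
    # keep only the keys that `fn` declares as parameters
    params = set(inspect.signature(fn).parameters)
    params.discard("self")
    return {k: v for k, v in stored.items() if k in params}

def l1_method(amount: float, importance_scores: Any = None) -> None:
    ...

assert resolve_kwargs(l1_method, {"amount": 0.5, "n": 2}) == {"amount": 0.5}
```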
def _apply_global_pruning(self, amount: float):
def _apply_global_pruning(self, amount: float) -> None:
pytorch_prune.global_unstructured(
self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
)
@@ -321,7 +329,7 @@ class ModelPruning(Callback):
mask = getattr(module, attr)
return (mask == 0).sum().item(), mask.numel()
def apply_pruning(self, amount: Union[int, float]):
def apply_pruning(self, amount: Union[int, float]) -> None:
""" Applies pruning to ``parameters_to_prune``. """
if self._verbose:
prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
@@ -338,7 +346,7 @@ class ModelPruning(Callback):
@rank_zero_only
def _log_sparsity_stats(
self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
):
) -> None:
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
prev_total_zeros = sum(zeros for zeros, _ in prev)
curr_total_zeros = sum(zeros for zeros, _ in curr)
@@ -357,7 +365,7 @@ class ModelPruning(Callback):
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
)
def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
parameters_to_prune = self.sanitize_parameters_to_prune(
pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
)
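Aside: `'pl.Trainer'` is a string (forward-reference) annotation, evaluated only by type checkers. This commit imports `pytorch_lightning as pl` at module level; the `TYPE_CHECKING` guard below is the common alternative when even the runtime import would be circular (sketch only):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only during type checking; never executed at runtime
    import pytorch_lightning as pl

def hook(trainer: 'pl.Trainer') -> None:
    ...  # the string annotation is not evaluated when the module loads
```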
@@ -370,29 +378,34 @@ class ModelPruning(Callback):
self._original_layers = {}
for i, (module, name) in enumerate(self._parameters_to_prune):
id_ = id(module)
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
self._original_layers[id_]["names"].append((i, name))
def on_train_epoch_end(self, trainer, pl_module: LightningModule):
current_epoch = trainer.current_epoch
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: # type: ignore
current_epoch = pl_module.current_epoch
prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
amount = self.amount(current_epoch) if callable(self.amount) else self.amount
if not prune or not amount:
return
self.apply_pruning(amount)
if (
self._use_lottery_ticket_hypothesis(current_epoch)
if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis
if callable(self._use_lottery_ticket_hypothesis) else self._use_lottery_ticket_hypothesis
):
self.apply_lottery_ticket_hypothesis()
def on_train_end(self, trainer, pl_module: LightningModule):
def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
if self._make_pruning_permanent:
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
self.make_pruning_permanent(pl_module)
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
def on_save_checkpoint(
self,
trainer: 'pl.Trainer',
pl_module: LightningModule,
checkpoint: Dict[str, Any],
) -> Dict[str, Any]:
if self._make_pruning_permanent:
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
prev_device = pl_module.device
@@ -402,11 +415,13 @@ class ModelPruning(Callback):
checkpoint["state_dict"] = copy.state_dict()
pl_module.to(prev_device)
return checkpoint
@staticmethod
def sanitize_parameters_to_prune(
pl_module: LightningModule,
parameters_to_prune: Optional[_PARAM_LIST] = None,
parameter_names: Optional[List[str]] = None,
parameters_to_prune: _PARAM_LIST = (),
parameter_names: Sequence[str] = (),
) -> _PARAM_LIST:
"""
This function is responsible for sanitizing ``parameters_to_prune`` and ``parameter_names``.
@@ -415,13 +430,13 @@ class ModelPruning(Callback):
Raises:
MisconfigurationException:
If ``parameters_to_prune`` doesn't exist in the model, or
if ``parameters_to_prune`` is neither a list of tuple nor ``None``.
if ``parameters_to_prune`` is neither a list nor a tuple.
"""
parameters = parameter_names or ModelPruning.PARAMETER_NAMES
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
if parameters_to_prune is None:
if not parameters_to_prune:
parameters_to_prune = [(m, p) for p in parameters for m in current_modules
if getattr(m, p, None) is not None]
elif (

requirements.txt

@@ -10,3 +10,4 @@ tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file in
torchmetrics>=0.2.0
pyDeprecate==0.3.0
packaging
typing-extensions # TypedDict support for python<3.8
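`TypedDict` only joined the standard `typing` module in Python 3.8, hence this new requirement. Pinning `typing-extensions` keeps the import unconditional; a version-guarded import, shown as a sketch, is the usual alternative:

```python
import sys

if sys.version_info >= (3, 8):
    from typing import TypedDict  # stdlib from 3.8 onwards
else:
    from typing_extensions import TypedDict  # backport for older Pythons
```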

setup.cfg

@@ -117,6 +117,9 @@ disable_error_code = attr-defined
# todo: add proper typing to this module...
[mypy-pytorch_lightning.callbacks.*]
ignore_errors = True
# whitelist
[mypy-pytorch_lightning.callbacks.pruning]
ignore_errors = False
# todo: add proper typing to this module...
[mypy-pytorch_lightning.core.*]
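With this override, mypy checks `pytorch_lightning.callbacks.pruning` even though the broader `[mypy-pytorch_lightning.callbacks.*]` section still sets `ignore_errors = True`, since the more specific section takes precedence. One way to confirm locally, using mypy's Python API (which mirrors the CLI):

```python
from mypy import api

# returns (stdout, stderr, exit_status), like running `mypy <path>`
stdout, stderr, status = api.run(["pytorch_lightning/callbacks/pruning.py"])
print(stdout or stderr)
```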