Add typing to `ModelPruning` callback (#7529)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
608de6abf4
commit
dbea5bb710
|
@ -19,12 +19,14 @@ import inspect
|
|||
import logging
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.utils.prune as pytorch_prune
|
||||
from torch import nn
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
|
||||
|
@ -47,8 +49,9 @@ _PYTORCH_PRUNING_METHOD = {
|
|||
}
|
||||
|
||||
_PARAM_TUPLE = Tuple[nn.Module, str]
|
||||
_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]]
|
||||
_PARAM_LIST = Sequence[_PARAM_TUPLE]
|
||||
_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)
|
||||
_LayerRef = TypedDict('_LayerRef', {'data': nn.Module, 'names': List[Tuple[int, str]]})
|
||||
|
||||
|
||||
class ModelPruning(Callback):
|
||||
|
@ -57,7 +60,7 @@ class ModelPruning(Callback):
|
|||
def __init__(
|
||||
self,
|
||||
pruning_fn: Union[Callable, str],
|
||||
parameters_to_prune: Optional[_PARAM_LIST] = None,
|
||||
parameters_to_prune: _PARAM_LIST = (),
|
||||
parameter_names: Optional[List[str]] = None,
|
||||
use_global_unstructured: bool = True,
|
||||
amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5,
|
||||
|
@ -153,9 +156,9 @@ class ModelPruning(Callback):
|
|||
self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
|
||||
self._resample_parameters = resample_parameters
|
||||
self._parameter_names = parameter_names or self.PARAMETER_NAMES
|
||||
self._global_kwargs = {}
|
||||
self._original_layers = None
|
||||
self._pruning_fn_name = None
|
||||
self._global_kwargs: Dict[str, Any] = {}
|
||||
self._original_layers: Optional[Dict[int, _LayerRef]] = None
|
||||
self._pruning_fn_name: Optional[str] = None
|
||||
|
||||
for name in self._parameter_names:
|
||||
if name not in self.PARAMETER_NAMES:
|
||||
|
@ -196,9 +199,10 @@ class ModelPruning(Callback):
|
|||
" HINT: if passing a `BasePruningMethod`, pass the the class, not an instance"
|
||||
)
|
||||
|
||||
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
|
||||
# need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attribute
|
||||
if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": # type: ignore
|
||||
raise MisconfigurationException(
|
||||
'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
|
||||
'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.' # type: ignore
|
||||
f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
|
||||
)
|
||||
|
||||
|
@ -206,7 +210,7 @@ class ModelPruning(Callback):
|
|||
self._apply_pruning = apply_pruning
|
||||
self._make_pruning_permanent = make_pruning_permanent
|
||||
|
||||
if not isinstance(amount, (int, float, Callable)):
|
||||
if not (isinstance(amount, (int, float)) or callable(amount)):
|
||||
raise MisconfigurationException(
|
||||
"`amount` should be provided and be either an int, a float or Callable function."
|
||||
)
|
||||
|
@ -218,13 +222,13 @@ class ModelPruning(Callback):
|
|||
|
||||
self._verbose = verbose
|
||||
|
||||
def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]:
|
||||
def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST:
|
||||
"""
|
||||
This function can be overridden to control which module to prune.
|
||||
"""
|
||||
return parameters_to_prune
|
||||
|
||||
def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]:
|
||||
def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]:
|
||||
"""
|
||||
This function takes `pruning_fn`, a function name.
|
||||
|
||||
|
@ -232,11 +236,13 @@ class ModelPruning(Callback):
|
|||
ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`.
|
||||
|
||||
"""
|
||||
pruning_fn = (
|
||||
_PYTORCH_PRUNING_METHOD[pruning_fn]
|
||||
if self._use_global_unstructured else _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
|
||||
)
|
||||
assert callable(pruning_fn)
|
||||
if self._use_global_unstructured:
|
||||
pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
|
||||
self._global_kwargs = kwargs
|
||||
else:
|
||||
pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
|
||||
# save the function __name__ now because partial does not include it
|
||||
# and there are issues setting the attribute manually in ddp.
|
||||
self._pruning_fn_name = pruning_fn.__name__
|
||||
|
@ -245,10 +251,10 @@ class ModelPruning(Callback):
|
|||
return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_pruning_fn(pruning_fn, **kwargs):
|
||||
def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable:
|
||||
return partial(pruning_fn, **kwargs)
|
||||
|
||||
def make_pruning_permanent(self, pl_module: LightningModule):
|
||||
def make_pruning_permanent(self, pl_module: LightningModule) -> None:
|
||||
"""
|
||||
Removes pruning buffers from any pruned modules
|
||||
|
||||
|
@ -261,14 +267,14 @@ class ModelPruning(Callback):
|
|||
hook.remove(module)
|
||||
del module._forward_pre_hooks[k]
|
||||
|
||||
def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
|
||||
def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str) -> None:
|
||||
trained = getattr(module, tensor_name)
|
||||
orig = getattr(orig_module, tensor_name)
|
||||
if trained is None or orig is None:
|
||||
return
|
||||
trained.data = orig.data.to(trained.device)
|
||||
|
||||
def apply_lottery_ticket_hypothesis(self):
|
||||
def apply_lottery_ticket_hypothesis(self) -> None:
|
||||
r"""
|
||||
Lottery ticket hypothesis algorithm (see page 2 of the paper):
|
||||
|
||||
|
@ -282,33 +288,35 @@ class ModelPruning(Callback):
|
|||
The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta`
|
||||
""" # noqa: E501
|
||||
|
||||
def copy_param(new, old, name: str) -> None:
|
||||
def copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
|
||||
dst = getattr(new, name)
|
||||
src = getattr(old, name)
|
||||
if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor):
|
||||
return
|
||||
dst.data = src.data.to(dst.device)
|
||||
|
||||
assert self._original_layers is not None
|
||||
for d in self._original_layers.values():
|
||||
copy, names = d["data"], d["names"]
|
||||
if self._resample_parameters and hasattr(copy, "reset_parameters"):
|
||||
copy = d["data"]
|
||||
names = d["names"]
|
||||
if self._resample_parameters and hasattr(copy, "reset_parameters") and callable(copy.reset_parameters):
|
||||
copy = deepcopy(copy) # keep the original parameters
|
||||
copy.reset_parameters()
|
||||
for i, name in names:
|
||||
new, new_name = self._parameters_to_prune[i]
|
||||
copy_param(new, copy, name)
|
||||
|
||||
def _apply_local_pruning(self, amount: float):
|
||||
def _apply_local_pruning(self, amount: float) -> None:
|
||||
for module, name in self._parameters_to_prune:
|
||||
self.pruning_fn(module, name=name, amount=amount)
|
||||
|
||||
def _resolve_global_kwargs(self, amount: float):
|
||||
def _resolve_global_kwargs(self, amount: float) -> Dict[str, Any]:
|
||||
self._global_kwargs["amount"] = amount
|
||||
params = set(inspect.signature(self.pruning_fn).parameters)
|
||||
params.discard("self")
|
||||
return {k: v for k, v in self._global_kwargs.items() if k in params}
|
||||
|
||||
def _apply_global_pruning(self, amount: float):
|
||||
def _apply_global_pruning(self, amount: float) -> None:
|
||||
pytorch_prune.global_unstructured(
|
||||
self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
|
||||
)
|
||||
|
@ -321,7 +329,7 @@ class ModelPruning(Callback):
|
|||
mask = getattr(module, attr)
|
||||
return (mask == 0).sum().item(), mask.numel()
|
||||
|
||||
def apply_pruning(self, amount: Union[int, float]):
|
||||
def apply_pruning(self, amount: Union[int, float]) -> None:
|
||||
""" Applies pruning to ``parameters_to_prune``. """
|
||||
if self._verbose:
|
||||
prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
|
||||
|
@ -338,7 +346,7 @@ class ModelPruning(Callback):
|
|||
@rank_zero_only
|
||||
def _log_sparsity_stats(
|
||||
self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
|
||||
):
|
||||
) -> None:
|
||||
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
|
||||
prev_total_zeros = sum(zeros for zeros, _ in prev)
|
||||
curr_total_zeros = sum(zeros for zeros, _ in curr)
|
||||
|
@ -357,7 +365,7 @@ class ModelPruning(Callback):
|
|||
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
|
||||
)
|
||||
|
||||
def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
|
||||
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
|
||||
parameters_to_prune = self.sanitize_parameters_to_prune(
|
||||
pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
|
||||
)
|
||||
|
@ -370,29 +378,34 @@ class ModelPruning(Callback):
|
|||
self._original_layers = {}
|
||||
for i, (module, name) in enumerate(self._parameters_to_prune):
|
||||
id_ = id(module)
|
||||
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
|
||||
self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
|
||||
self._original_layers[id_]["names"].append((i, name))
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: LightningModule):
|
||||
current_epoch = trainer.current_epoch
|
||||
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
|
||||
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
|
||||
def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None: # type: ignore
|
||||
current_epoch = pl_module.current_epoch
|
||||
prune = self._apply_pruning(current_epoch) if callable(self._apply_pruning) else self._apply_pruning
|
||||
amount = self.amount(current_epoch) if callable(self.amount) else self.amount
|
||||
if not prune or not amount:
|
||||
return
|
||||
self.apply_pruning(amount)
|
||||
|
||||
if (
|
||||
self._use_lottery_ticket_hypothesis(current_epoch)
|
||||
if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis
|
||||
if callable(self._use_lottery_ticket_hypothesis) else self._use_lottery_ticket_hypothesis
|
||||
):
|
||||
self.apply_lottery_ticket_hypothesis()
|
||||
|
||||
def on_train_end(self, trainer, pl_module: LightningModule):
|
||||
def on_train_end(self, trainer: 'pl.Trainer', pl_module: LightningModule) -> None:
|
||||
if self._make_pruning_permanent:
|
||||
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
|
||||
self.make_pruning_permanent(pl_module)
|
||||
|
||||
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
|
||||
def on_save_checkpoint(
|
||||
self,
|
||||
trainer: 'pl.Trainer',
|
||||
pl_module: LightningModule,
|
||||
checkpoint: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
if self._make_pruning_permanent:
|
||||
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
|
||||
prev_device = pl_module.device
|
||||
|
@ -402,11 +415,13 @@ class ModelPruning(Callback):
|
|||
checkpoint["state_dict"] = copy.state_dict()
|
||||
pl_module.to(prev_device)
|
||||
|
||||
return checkpoint
|
||||
|
||||
@staticmethod
|
||||
def sanitize_parameters_to_prune(
|
||||
pl_module: LightningModule,
|
||||
parameters_to_prune: Optional[_PARAM_LIST] = None,
|
||||
parameter_names: Optional[List[str]] = None,
|
||||
parameters_to_prune: _PARAM_LIST = (),
|
||||
parameter_names: Sequence[str] = (),
|
||||
) -> _PARAM_LIST:
|
||||
"""
|
||||
This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``.
|
||||
|
@ -415,13 +430,13 @@ class ModelPruning(Callback):
|
|||
Raises:
|
||||
MisconfigurationException:
|
||||
If ``parameters_to_prune`` doesn't exist in the model, or
|
||||
if ``parameters_to_prune`` is neither a list of tuple nor ``None``.
|
||||
if ``parameters_to_prune`` is neither a list nor a tuple.
|
||||
"""
|
||||
parameters = parameter_names or ModelPruning.PARAMETER_NAMES
|
||||
|
||||
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]
|
||||
|
||||
if parameters_to_prune is None:
|
||||
if not parameters_to_prune:
|
||||
parameters_to_prune = [(m, p) for p in parameters for m in current_modules
|
||||
if getattr(m, p, None) is not None]
|
||||
elif (
|
||||
|
|
|
@ -10,3 +10,4 @@ tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file in
|
|||
torchmetrics>=0.2.0
|
||||
pyDeprecate==0.3.0
|
||||
packaging
|
||||
typing-extensions # TypedDict support for python<3.8
|
||||
|
|
|
@ -117,6 +117,9 @@ disable_error_code = attr-defined
|
|||
# todo: add proper typing to this module...
|
||||
[mypy-pytorch_lightning.callbacks.*]
|
||||
ignore_errors = True
|
||||
# whitelist
|
||||
[mypy-pytorch_lightning.callbacks.pruning]
|
||||
ignore_errors = False
|
||||
|
||||
# todo: add proper typing to this module...
|
||||
[mypy-pytorch_lightning.core.*]
|
||||
|
|
Loading…
Reference in New Issue