# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
|
|||
|
ModelPruning
|
|||
|
^^^^^^^^^^^^
|
|||
|
|
|||
|
"""
|
|||
|
import inspect
from copy import deepcopy
from functools import partial
from typing import Callable, List, Optional, Tuple, Union

from torch import nn
from torch.nn.modules.container import ModuleDict, ModuleList

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _PYTORCH_PRUNE_AVAILABLE:
    import torch.nn.utils.prune as pytorch_prune


_PYTORCH_PRUNING_FUNCTIONS = {
    "ln_structured": pytorch_prune.ln_structured,
    "l1_unstructured": pytorch_prune.l1_unstructured,
    "random_structured": pytorch_prune.random_structured,
    "random_unstructured": pytorch_prune.random_unstructured,
}

_PYTORCH_PRUNING_METHOD = {
    "ln_structured": pytorch_prune.LnStructured,
    "l1_unstructured": pytorch_prune.L1Unstructured,
    "random_structured": pytorch_prune.RandomStructured,
    "random_unstructured": pytorch_prune.RandomUnstructured,
}


class ModelPruning(Callback):

    PARAMETER_NAMES = ("weight", "bias")

    def __init__(
        self,
        pruning_fn: Optional[Union[Callable, str]] = None,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None,
        parameter_names: List[str] = ["weight"],
        use_global_unstructured: bool = True,
        amount: Union[int, float, Callable] = 0.5,
        make_pruning_permanent: bool = True,
        use_lottery_ticket_hypothesis: bool = True,
        pruning_dim: Optional[int] = None,
        pruning_norm: Optional[int] = None,
    ) -> None:
"""
|
|||
|
|
|||
|
Pruning Callback relying on PyTorch prune utils.
|
|||
|
|
|||
|
This callback is responsible to prune networks parameters
|
|||
|
during your training.
|
|||
|
|
|||
|
Find here the PyTorch (Pruning Tutorial)[https://pytorch.org/tutorials/intermediate/pruning_tutorial.html]
|
|||
|
|
|||
|
.. code-block:: python
|
|||
|
|
|||
|
parameters_to_prune = [
|
|||
|
(model.mlp_1, "weight"),
|
|||
|
(model.mlp_2, "weight")
|
|||
|
]
|
|||
|
|
|||
|
trainer = Trainer(
|
|||
|
callbacks=[
|
|||
|
ModelPruning(
|
|||
|
pruning_fn='l1_unstructured',
|
|||
|
parameters_to_prune=parameters_to_prune,
|
|||
|
amount=0.01,
|
|||
|
use_global_unstructured=True,
|
|||
|
)
|
|||
|
]
|
|||
|
)
|
|||
|
|
|||
|
When `parameters_to_prune` is None, `parameters_to_prune` will contains all parameters from the model.
|
|||
|
The user can override `filter_parameters_to_prune` to filter any nn.Module to be pruned.
|
|||
|
|
|||
|
Args:
|
|||
|
|
|||
|
pruning_fn: function from torch.nn.utils.prune module
|
|||
|
or your based own subclasses from PyTorch ``BasePruningMethod``.
|
|||
|
Can be string e.g. `"l1_unstructured"`.
|
|||
|
See pytorch docs for more details.
|
|||
|
|
|||
|
parameters_to_prune: list of strings or list of tuple with
|
|||
|
nn.Module and its associated string name parameters.
|
|||
|
|
|||
|
parameter_names: List of parameter names to be used from nn.Module.
|
|||
|
Can either be `weight` or `bias`.
|
|||
|
|
|||
|
use_global_unstructured: Whether to apply pruning globally on the model.
|
|||
|
If parameters_to_prune is provided, global_unstructured will be restricted on them.
|
|||
|
|
|||
|
amount: quantity of parameters to prune:
|
|||
|
|
|||
|
- float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune.
|
|||
|
- int, it represents the absolute number of parameters to prune.
|
|||
|
- Callable, the function will be called on every epoch.
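
                A minimal sketch of a callable ``amount`` (``prune_schedule`` below is a
                hypothetical user-defined function, not part of this module). Returning
                ``0`` or ``None`` skips pruning for that epoch:

                .. code-block:: python

                    def prune_schedule(epoch: int):
                        # warm up for 4 epochs, then prune 5% of the remaining weights each epoch
                        return 0.05 if epoch >= 4 else 0

                    trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=prune_schedule)])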

            make_pruning_permanent: if True, all reparametrization pre-hooks and
                mask tensors will be removed when fitting ends.

            use_lottery_ticket_hypothesis: whether to use the algorithm described in
                "The Lottery Ticket Hypothesis" (https://arxiv.org/pdf/1803.03635.pdf).

            pruning_dim: if you are using a structured pruning method, you need
                to specify the dimension.

            pruning_norm: if you are using ``ln_structured``, you need to specify the norm.

        """
        self._use_global_unstructured = use_global_unstructured
        self._parameters_to_prune = parameters_to_prune
        self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
        self._parameter_names = parameter_names or self.PARAMETER_NAMES
        self._global_kwargs = {}
        self._initial_parameters_to_prune = None

        for param_name in self._parameter_names:
            if param_name not in self.PARAMETER_NAMES:
                raise MisconfigurationException(
                    f"The provided parameter_names {param_name} isn't in {self.PARAMETER_NAMES}"
                )

        if isinstance(pruning_fn, str):
            pruning_fn = pruning_fn.lower()
            if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
                raise MisconfigurationException(
                    f"The provided pruning_fn {pruning_fn} isn't available among "
                    f"the PyTorch built-ins {_PYTORCH_PRUNING_FUNCTIONS.keys()}"
                )
            if "unstructured" not in pruning_fn:
                if pruning_dim is None:
                    raise MisconfigurationException(
                        "When requesting `structured` pruning, the `pruning_dim` should be provided."
                    )
                if pruning_fn == "ln_structured":
                    if pruning_norm is None:
                        raise MisconfigurationException(
                            "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
                        )
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim, n=pruning_norm)
                else:
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim)
            else:
                pruning_fn = self._create_pruning_fn(pruning_fn)

        else:
            # accept any (possibly indirect) subclass of BasePruningMethod
            if not (inspect.isclass(pruning_fn) and issubclass(pruning_fn, pytorch_prune.BasePruningMethod)):
                raise MisconfigurationException(
                    f"pruning_fn is expected to be a str in {_PYTORCH_PRUNING_FUNCTIONS.keys()} "
                    f"or a PyTorch `BasePruningMethod` subclass. Found: {pruning_fn}"
                )

            if not use_global_unstructured:
                raise MisconfigurationException(
                    "A `BasePruningMethod` subclass is currently only supported with `use_global_unstructured=True`."
                )

        if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
            raise MisconfigurationException(
                'Only the "unstructured" PRUNING_TYPE is supported for '
                f"the pruning method. Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}."
            )

        self.pruning_fn = pruning_fn

        self.make_pruning_permanent = make_pruning_permanent

        if not isinstance(amount, (int, float, Callable)):
            raise MisconfigurationException(
                "`amount` should be an int, a float, or a Callable."
            )

        self.amount = amount

    def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]]):
        """
        This function can be overridden to control which modules get pruned.
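
        For instance, an override that keeps only ``nn.Linear`` layers (a hypothetical
        sketch, not part of this callback):

        Example::

            class LinearOnlyPruning(ModelPruning):
                def filter_parameters_to_prune(self, parameters_to_prune):
                    # keep only the (module, name) pairs whose module is a Linear layer
                    return [(m, n) for m, n in parameters_to_prune if isinstance(m, nn.Linear)]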
        """
        return parameters_to_prune

    def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs):
        """
        This function takes ``pruning_fn``, a function name, and resolves it.

        If ``use_global_unstructured`` is True, ``pruning_fn`` is resolved into its associated
        ``BasePruningMethod`` class; otherwise, it is resolved into its function counterpart
        from ``torch.nn.utils.prune``.
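
        Example (assuming ``"l1_unstructured"`` is requested)::

            # use_global_unstructured=True:
            #   -> pytorch_prune.L1Unstructured (a class, later passed to global_unstructured)
            # use_global_unstructured=False:
            #   -> partial(pytorch_prune.l1_unstructured) (applied module by module)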

        """
        if self._use_global_unstructured:
            pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
            self._global_kwargs = kwargs
            return pruning_fn
        else:
            return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)

    @staticmethod
    def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
        return partial(pruning_fn, **kwargs)

    def _make_pruning_permanent(self):
        for module, param_name in self._parameters_to_prune:
            pytorch_prune.remove(module, param_name)

    def _resolve_amount(self, current_epoch: int) -> Union[int, float, None]:
        # the user may provide a schedule: a callable mapping the current epoch to an amount
        if isinstance(self.amount, Callable):
            return self.amount(current_epoch)
        return self.amount

    def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
        """
        "The Lottery Ticket Hypothesis" (https://arxiv.org/pdf/1803.03635.pdf) algorithm:

        1. Randomly initialize a neural network f(x; θ0) (where θ0 ∼ Dθ).
        2. Train the network for j iterations, arriving at parameters θj.
        3. Prune p% of the parameters in θj, creating a mask m.
        4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m ⊙ θ0).
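
        A sketch of the effect on one parameter tensor (names are illustrative)::

            effective_weight = mask * theta_j    # after step 3 (pruned, trained weights)
            effective_weight = mask * theta_0    # after step 4 (the winning ticket)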

        This function implements step 4.
        """
        trained = getattr(module, tensor_name)
        orig = getattr(orig_module, tensor_name)
        if trained is None or orig is None:
            return
        trained.data = orig.data.to(trained.device)

    def apply_lottery_ticket_hypothesis(self):
        for (mod, tensor_name), (initial_mod, _) in zip(self._parameters_to_prune, self._initial_parameters_to_prune):
            self._restore_original_weights(mod, initial_mod, tensor_name)

    def _apply_local_pruning(self, amount: float):
        for module, param in self._parameters_to_prune:
            self.pruning_fn(module, name=param, amount=amount)

    def _resolve_global_kwargs(self, amount: float):
        # forward only the kwargs that the resolved pruning function actually accepts
        kwargs = {}
        self._global_kwargs["amount"] = amount
        params = inspect.signature(self.pruning_fn).parameters
        for p_name in params:
            if p_name != "self":
                param = self._global_kwargs.get(p_name)
                if param is not None:
                    kwargs[p_name] = param
        return kwargs

    def _apply_global_pruning(self, amount: float):
        pytorch_prune.global_unstructured(
            self._parameters_to_prune,
            pruning_method=self.pruning_fn,
            **self._resolve_global_kwargs(amount)
        )

    def apply_pruning(self, trainer: 'pl.Trainer', pl_module: LightningModule):
        amount = self._resolve_amount(trainer.current_epoch)

        # the user could control the pruning frequency with amount_fn
        if amount == 0 or amount is None:
            return

        if self._use_global_unstructured:
            self._apply_global_pruning(amount)
        else:
            self._apply_local_pruning(amount)

        if self._use_lottery_ticket_hypothesis:
            self.apply_lottery_ticket_hypothesis()

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        parameters_to_prune = self.sanitize_parameters_to_prune(
            pl_module, self._parameters_to_prune, parameters=self._parameter_names)

        self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)

        if self._use_lottery_ticket_hypothesis:
            # keep a copy of the original weights so they can be restored after pruning
            self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune]

    def on_epoch_end(self, trainer, pl_module):
        self.apply_pruning(trainer, pl_module)

        if self.make_pruning_permanent:
            self._make_pruning_permanent()

    @staticmethod
    def _sanitize_parameters_to_prune(p):
        """
        Check that the provided element is a pair of:

        * nn.Module
        * str

        Example::

            parameters_to_prune = [
                (model.mlp_1, "weight"),
                (model.mlp_2, "weight")
            ]
        """
        return len(p) == 2 and isinstance(p[0], nn.Module) and isinstance(p[1], str)

    @staticmethod
    def sanitize_parameters_to_prune(
        pl_module: LightningModule,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]],
        parameters: List[str] = ["weight"]
    ) -> List:
        """
        This function validates the provided ``parameters_to_prune`` and ``parameters``.
        If ``parameters_to_prune`` is None, it is generated from all parameters of the model.
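
        Example of the generated output for a model with two layers named ``mlp_1`` and
        ``mlp_2`` (hypothetical attribute names)::

            sanitize_parameters_to_prune(model, None, parameters=["weight"])
            # -> [(model.mlp_1, "weight"), (model.mlp_2, "weight")]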
"""
|
|||
|
|
|||
|
is_parameters_to_prune_none = parameters_to_prune is None
|
|||
|
current_modules = [
|
|||
|
m for m in pl_module.modules()
|
|||
|
if not isinstance(m, (LightningModule, ModuleDict, ModuleList))
|
|||
|
]
|
|||
|
|
|||
|
if is_parameters_to_prune_none:
|
|||
|
parameters_to_prune = []
|
|||
|
for p in parameters:
|
|||
|
for m in current_modules:
|
|||
|
param = getattr(m, p, None)
|
|||
|
if param is not None:
|
|||
|
parameters_to_prune.append((m, p))
|
|||
|
|
|||
|
if isinstance(parameters_to_prune, (tuple, list)) \
|
|||
|
and len(parameters_to_prune) > 0 and not is_parameters_to_prune_none:
|
|||
|
|
|||
|
if all(
|
|||
|
isinstance(p, (list, tuple)) and ModelPruning._sanitize_parameters_to_prune(p)
|
|||
|
for p in parameters_to_prune
|
|||
|
):
|
|||
|
|
|||
|
missing_modules = []
|
|||
|
missing_parameters = []
|
|||
|
|
|||
|
for module, param_name in parameters_to_prune:
|
|||
|
if module not in current_modules:
|
|||
|
missing_modules.append(module)
|
|||
|
continue
|
|||
|
|
|||
|
parameter = getattr(module, param_name)
|
|||
|
|
|||
|
if parameter is None:
|
|||
|
missing_parameters.append(parameter)
|
|||
|
|
|||
|
if len(missing_modules) > 0 or len(missing_parameters) > 0:
|
|||
|
raise MisconfigurationException(
|
|||
|
"Ths provided parameters_to_tune doesn't exist in the model."
|
|||
|
f" Found mismatching modules: {missing_modules} and missing_parameters: {missing_parameters}"
|
|||
|
)
|
|||
|
|
|||
|
else:
|
|||
|
raise MisconfigurationException(
|
|||
|
"The provided parameters_to_prune should either be list of tuple "
|
|||
|
"with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")
|
|||
|
else:
|
|||
|
if not isinstance(parameters_to_prune, (list, tuple)):
|
|||
|
raise MisconfigurationException(
|
|||
|
"The provided parameters_to_prune should either be list of tuple "
|
|||
|
"with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")
|
|||
|
|
|||
|
return parameters_to_prune
|