# lightning/pytorch_lightning/callbacks/pruning.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
ModelPruning
^^^^^^^^^^^^
"""
import inspect
from copy import deepcopy
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
from torch import nn
from torch.nn.modules.container import ModuleDict, ModuleList
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _PYTORCH_PRUNE_AVAILABLE:
    import torch.nn.utils.prune as pytorch_prune

    # Maps user-facing string names to the functional pruning API from
    # ``torch.nn.utils.prune`` — used for local (per-module) pruning.
    # Defined under the availability guard because both dicts dereference
    # ``pytorch_prune`` at definition time.
    _PYTORCH_PRUNING_FUNCTIONS = {
        "ln_structured": pytorch_prune.ln_structured,
        "l1_unstructured": pytorch_prune.l1_unstructured,
        "random_structured": pytorch_prune.random_structured,
        "random_unstructured": pytorch_prune.random_unstructured,
    }

    # Maps the same names to their ``BasePruningMethod`` subclasses — used
    # with ``torch.nn.utils.prune.global_unstructured``.
    _PYTORCH_PRUNING_METHOD = {
        "ln_structured": pytorch_prune.LnStructured,
        "l1_unstructured": pytorch_prune.L1Unstructured,
        "random_structured": pytorch_prune.RandomStructured,
        "random_unstructured": pytorch_prune.RandomUnstructured,
    }
class ModelPruning(Callback):
    r"""Prune a model's parameters during training using ``torch.nn.utils.prune``."""

    PARAMETER_NAMES = ("weight", "bias")

    def __init__(
        self,
        pruning_fn: Optional[Union[Callable, str]] = None,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None,
        parameter_names: Union[List[str], Tuple[str, ...]] = ("weight",),
        use_global_unstructured: bool = True,
        amount: Union[int, float, Callable] = 0.5,
        make_pruning_permanent: bool = True,
        use_lottery_ticket_hypothesis: bool = True,
        pruning_dim: Optional[int] = None,
        pruning_norm: Optional[int] = None,
    ) -> None:
        """
        Pruning Callback relying on PyTorch prune utils.

        This callback is responsible for pruning the network's parameters
        during training.

        See the PyTorch `Pruning Tutorial
        <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_.

        .. code-block:: python

            parameters_to_prune = [
                (model.mlp_1, "weight"),
                (model.mlp_2, "weight")
            ]
            trainer = Trainer(
                callbacks=[
                    ModelPruning(
                        pruning_fn='l1_unstructured',
                        parameters_to_prune=parameters_to_prune,
                        amount=0.01,
                        use_global_unstructured=True,
                    )
                ]
            )

        When ``parameters_to_prune`` is None, it will be populated with all
        parameters from the model. Override ``filter_parameters_to_prune`` to
        control which ``nn.Module`` are pruned.

        Args:
            pruning_fn: function from ``torch.nn.utils.prune`` module or your
                own subclass of PyTorch's ``BasePruningMethod``. Can be a
                string, e.g. ``"l1_unstructured"``. See PyTorch docs for
                more details.
            parameters_to_prune: list of tuples pairing an ``nn.Module`` with
                the string name of the parameter to prune within it.
            parameter_names: parameter names to be used from ``nn.Module``.
                Can either be ``"weight"`` or ``"bias"``.
            use_global_unstructured: whether to apply pruning globally on the
                model. If ``parameters_to_prune`` is provided, global pruning
                will be restricted to them.
            amount: quantity of parameters to prune:

                - float, between 0.0 and 1.0: fraction of parameters to prune.
                - int: absolute number of parameters to prune.
                - Callable: called with the current epoch on every epoch; the
                  return value is used as the amount for that epoch.

            make_pruning_permanent: if True, the pruning reparametrization
                pre-hooks and mask tensors are removed after each pruning
                step (on epoch end).
            use_lottery_ticket_hypothesis: whether to use the algorithm
                described in "The lottery ticket hypothesis"
                (https://arxiv.org/pdf/1803.03635.pdf).
            pruning_dim: dimension to prune; required for structured methods.
            pruning_norm: norm to use; required for ``ln_structured``.

        Raises:
            MisconfigurationException: if any argument is inconsistent with
                the selected pruning method.
        """
        self._use_global_unstructured = use_global_unstructured
        self._parameters_to_prune = parameters_to_prune
        self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
        # Falsy ``parameter_names`` falls back to all supported names.
        self._parameter_names = parameter_names or self.PARAMETER_NAMES
        self._global_kwargs = {}
        self._initial_parameters_to_prune = None

        for param_name in self._parameter_names:
            if param_name not in self.PARAMETER_NAMES:
                raise MisconfigurationException(
                    f"The provided parameter_names {param_name} isn't in {self.PARAMETER_NAMES} "
                )

        if isinstance(pruning_fn, str):
            pruning_fn = pruning_fn.lower()
            if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
                raise MisconfigurationException(
                    f"The provided pruning_fn {pruning_fn} isn't available with "
                    f"PyTorch build-in {_PYTORCH_PRUNING_FUNCTIONS.keys()} "
                )
            if "unstructured" not in pruning_fn:
                # Structured methods require the dimension to prune along.
                if pruning_dim is None:
                    raise MisconfigurationException(
                        "When requesting `structured` pruning, the `pruning_dim` should be provided."
                    )
                if pruning_fn == "ln_structured":
                    if pruning_norm is None:
                        raise MisconfigurationException(
                            "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
                        )
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim, n=pruning_norm)
                else:
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim)
            else:
                pruning_fn = self._create_pruning_fn(pruning_fn)
        else:
            # Accept any (possibly indirect) subclass of ``BasePruningMethod``;
            # the previous direct-first-base check rejected valid subclasses.
            if not (inspect.isclass(pruning_fn) and issubclass(pruning_fn, pytorch_prune.BasePruningMethod)):
                raise MisconfigurationException(
                    f'pruning_fn is expected to be the str in {_PYTORCH_PRUNING_FUNCTIONS.keys()} '
                    f'or a `PyTorch BasePruningMethod`. Found: {pruning_fn}'
                )
            if not use_global_unstructured:
                raise MisconfigurationException(
                    '`PyTorch BasePruningMethod` is currently support only for `use_global_unstructured=True`. ')

        # ``global_unstructured`` only supports unstructured methods.
        if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
            raise MisconfigurationException(
                'Only "unstructured" PRUNING_TYPE supported for '
                f"the `pruning_method`. Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
            )

        self.pruning_fn = pruning_fn
        self.make_pruning_permanent = make_pruning_permanent

        if not isinstance(amount, (int, float)) and not callable(amount):
            raise MisconfigurationException(
                "amount should be provided and be either an int, a float or Callable function."
            )
        self.amount = amount

    def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]]):
        """Override to control which modules/parameters get pruned."""
        return parameters_to_prune

    def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs):
        """Resolve the ``pruning_fn`` name into a callable.

        If ``use_global_unstructured``, resolve into the associated PyTorch
        ``BasePruningMethod`` class (its kwargs are stashed for later use with
        ``global_unstructured``); otherwise resolve into the functional
        counterpart from ``torch.nn.utils.prune``.
        """
        if self._use_global_unstructured:
            self._global_kwargs = kwargs
            return _PYTORCH_PRUNING_METHOD[pruning_fn]
        return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)

    @staticmethod
    def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
        # Bind the method-specific kwargs (e.g. ``dim``, ``n``) once up-front.
        return partial(pruning_fn, **kwargs)

    def _make_pruning_permanent(self):
        # Remove the pruning reparametrization: the mask is baked into the
        # parameter and the mask tensor / forward pre-hook are dropped.
        for module, param_name in self._parameters_to_prune:
            pytorch_prune.remove(module, param_name)

    def _resolve_amount(self, current_epoch: int) -> float:
        # ``amount`` may be a schedule: a callable of the current epoch.
        return self.amount(current_epoch) if callable(self.amount) else self.amount

    def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
        """
        "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf) algorithm:

            1. Randomly initialize a neural network f(x; θ0).
            2. Train the network for j iterations, arriving at parameters θj.
            3. Prune p% of the parameters in θj, creating a mask m.
            4. Reset the remaining parameters to their values in θ0,
               creating the winning ticket f(x; m⊙θ0).

        This function is responsible for step 4.
        """
        trained = getattr(module, tensor_name)
        orig = getattr(orig_module, tensor_name)
        if trained is None or orig is None:
            return
        trained.data = orig.data.to(trained.device)

    def apply_lottery_ticket_hypothesis(self):
        """Reset every pruned parameter to its initial (pre-training) value."""
        for (mod, tensor_name), (initial_mod, _) in zip(self._parameters_to_prune, self._initial_parameters_to_prune):
            self._restore_original_weights(mod, initial_mod, tensor_name)

    def _apply_local_pruning(self, amount: float):
        # Prune each (module, parameter) pair independently.
        for module, param in self._parameters_to_prune:
            self.pruning_fn(module, name=param, amount=amount)

    def _resolve_global_kwargs(self, amount: float):
        """Select the stashed kwargs accepted by the pruning method's signature."""
        self._global_kwargs["amount"] = amount
        params = inspect.signature(self.pruning_fn).parameters
        kwargs = {}
        for p_name in params:
            if p_name != "self":
                param = self._global_kwargs.get(p_name)
                if param is not None:
                    kwargs[p_name] = param
        return kwargs

    def _apply_global_pruning(self, amount: float):
        # Prune across all collected parameters at once, ranking them jointly.
        pytorch_prune.global_unstructured(
            self._parameters_to_prune,
            pruning_method=self.pruning_fn,
            **self._resolve_global_kwargs(amount)
        )

    def apply_pruning(self, trainer: 'pl.Trainer', pl_module: LightningModule):
        """Run one pruning step for the current epoch."""
        amount = self._resolve_amount(trainer.current_epoch)
        # The user can control the pruning frequency with an amount callable:
        # returning 0/None skips pruning for this epoch.
        if amount is None or amount == 0:
            return
        if self._use_global_unstructured:
            self._apply_global_pruning(amount)
        else:
            self._apply_local_pruning(amount)
        if self._use_lottery_ticket_hypothesis:
            self.apply_lottery_ticket_hypothesis()

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        """Collect, validate and filter the parameters to prune before training."""
        parameters_to_prune = self.sanitize_parameters_to_prune(
            pl_module, self._parameters_to_prune, parameters=self._parameter_names)
        self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)
        if self._use_lottery_ticket_hypothesis:
            # Keep a snapshot of the original weights for step 4 of the
            # lottery ticket algorithm.
            self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune]

    def on_epoch_end(self, trainer, pl_module):
        self.apply_pruning(trainer, pl_module)
        if self.make_pruning_permanent:
            self._make_pruning_permanent()

    @staticmethod
    def _sanitize_parameters_to_prune(p):
        """
        Check that the provided element is a pair of (nn.Module, str).

        Example::

            parameters_to_prune = [
                (model.mlp_1, "weight"),
                (model.mlp_2, "weight")
            ]
        """
        return len(p) == 2 and isinstance(p[0], nn.Module) and isinstance(p[1], str)

    @staticmethod
    def sanitize_parameters_to_prune(
        pl_module: LightningModule,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]],
        parameters: Union[List[str], Tuple[str, ...]] = ("weight",)
    ) -> List:
        """Check the provided ``parameters_to_prune`` and ``parameters``.

        If ``parameters_to_prune`` is None, it is generated from all modules
        of ``pl_module`` that expose any of the given ``parameters``.

        Raises:
            MisconfigurationException: if ``parameters_to_prune`` is malformed
                or references modules/parameters absent from the model.
        """
        is_parameters_to_prune_none = parameters_to_prune is None
        # Container modules are skipped: their children are already yielded
        # individually by ``modules()``.
        current_modules = [
            m for m in pl_module.modules()
            if not isinstance(m, (LightningModule, ModuleDict, ModuleList))
        ]
        if is_parameters_to_prune_none:
            parameters_to_prune = []
            for p in parameters:
                for m in current_modules:
                    param = getattr(m, p, None)
                    if param is not None:
                        parameters_to_prune.append((m, p))
        if isinstance(parameters_to_prune, (tuple, list)) \
                and len(parameters_to_prune) > 0 and not is_parameters_to_prune_none:
            if all(
                isinstance(p, (list, tuple)) and ModelPruning._sanitize_parameters_to_prune(p)
                for p in parameters_to_prune
            ):
                missing_modules = []
                missing_parameters = []
                for module, param_name in parameters_to_prune:
                    if module not in current_modules:
                        missing_modules.append(module)
                        continue
                    # Default of None: a missing attribute should be reported
                    # below, not raise an AttributeError here.
                    parameter = getattr(module, param_name, None)
                    if parameter is None:
                        # Record the *name*; the value is None and useless in
                        # the error message.
                        missing_parameters.append(param_name)
                if missing_modules or missing_parameters:
                    raise MisconfigurationException(
                        "The provided parameters_to_prune doesn't exist in the model."
                        f" Found mismatching modules: {missing_modules} and missing_parameters: {missing_parameters}"
                    )
            else:
                raise MisconfigurationException(
                    "The provided parameters_to_prune should either be list of tuple "
                    "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")
        else:
            if not isinstance(parameters_to_prune, (list, tuple)):
                raise MisconfigurationException(
                    "The provided parameters_to_prune should either be list of tuple "
                    "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")
        return parameters_to_prune