lightning/pytorch_lightning/callbacks/pruning.py

391 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
ModelPruning
^^^^^^^^^^^^
"""
import inspect
from copy import deepcopy
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
from torch import nn
from torch.nn.modules.container import ModuleDict, ModuleList
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _PYTORCH_PRUNE_AVAILABLE:
import torch.nn.utils.prune as pytorch_prune
# Map the string names accepted by ``ModelPruning(pruning_fn=...)`` to the
# *functional* pruning API in ``torch.nn.utils.prune`` — used for local,
# per-module pruning (``use_global_unstructured=False``).
_PYTORCH_PRUNING_FUNCTIONS = {
    "ln_structured": pytorch_prune.ln_structured,
    "l1_unstructured": pytorch_prune.l1_unstructured,
    "random_structured": pytorch_prune.random_structured,
    "random_unstructured": pytorch_prune.random_unstructured,
}
# Map the same string names to the corresponding ``BasePruningMethod`` *classes*
# — used with ``pytorch_prune.global_unstructured`` (``use_global_unstructured=True``).
_PYTORCH_PRUNING_METHOD = {
    "ln_structured": pytorch_prune.LnStructured,
    "l1_unstructured": pytorch_prune.L1Unstructured,
    "random_structured": pytorch_prune.RandomStructured,
    "random_unstructured": pytorch_prune.RandomUnstructured,
}
class ModelPruning(Callback):
    r"""
    Pruning callback relying on PyTorch's ``torch.nn.utils.prune`` utilities.

    This callback is responsible for pruning network parameters during training.

    Find here the PyTorch `Pruning Tutorial
    <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_.

    .. code-block:: python

        parameters_to_prune = [
            (model.mlp_1, "weight"),
            (model.mlp_2, "weight")
        ]

        trainer = Trainer(
            callbacks=[
                ModelPruning(
                    pruning_fn='l1_unstructured',
                    parameters_to_prune=parameters_to_prune,
                    amount=0.01,
                    use_global_unstructured=True,
                )
            ]
        )

    When ``parameters_to_prune`` is None, ``parameters_to_prune`` will contain all
    parameters from the model. The user can override ``filter_parameters_to_prune``
    to filter any ``nn.Module`` to be pruned.
    """

    # The only parameter names a module may be pruned on.
    PARAMETER_NAMES = ("weight", "bias")

    def __init__(
        self,
        pruning_fn: Optional[Union[Callable, str]] = None,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None,
        parameter_names: Tuple[str, ...] = ("weight",),
        use_global_unstructured: bool = True,
        amount: Optional[Union[int, float, Callable]] = 0.5,
        make_pruning_permanent: Optional[bool] = True,
        use_lottery_ticket_hypothesis: Optional[bool] = True,
        pruning_dim: Optional[int] = None,
        pruning_norm: Optional[int] = None,
    ) -> None:
        """
        Args:
            pruning_fn: function name from the ``torch.nn.utils.prune`` module
                (e.g. ``"l1_unstructured"``) or your own subclass of PyTorch
                ``BasePruningMethod``. See the PyTorch docs for more details.
            parameters_to_prune: list of tuples pairing an ``nn.Module`` with the
                string name of the parameter to prune on that module.
            parameter_names: parameter names to be used from the modules. Each
                name must be either ``"weight"`` or ``"bias"``.
            use_global_unstructured: whether to apply pruning globally on the model.
                If ``parameters_to_prune`` is provided, global pruning is restricted to them.
            amount: quantity of parameters to prune:

                - float, between 0.0 and 1.0: the fraction of parameters to prune.
                - int: the absolute number of parameters to prune.
                - Callable: called with the current epoch number on every epoch;
                  returning ``0`` or ``None`` skips pruning for that epoch.

            make_pruning_permanent: if True, the reparametrization pre-hooks and
                mask tensors are removed after pruning (see ``on_epoch_end``).
            use_lottery_ticket_hypothesis: whether to use the algorithm described in
                "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf)
            pruning_dim: if you are using a structured pruning method you need
                to specify the dimension.
            pruning_norm: if you are using ``ln_structured`` you need to specify the norm.

        Raises:
            MisconfigurationException: if any provided argument is invalid or
                the combination of arguments is inconsistent.
        """
        super().__init__()
        self._use_global_unstructured = use_global_unstructured
        self._parameters_to_prune = parameters_to_prune
        self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis
        self._parameter_names = parameter_names or self.PARAMETER_NAMES
        self._global_kwargs = {}
        # Snapshot of the pre-training weights, filled in
        # ``on_before_accelerator_backend_setup`` when the lottery ticket
        # hypothesis is enabled.
        self._initial_parameters_to_prune = None

        for param_name in self._parameter_names:
            if param_name not in self.PARAMETER_NAMES:
                raise MisconfigurationException(
                    f"The provided parameter_names {param_name} isn't in {self.PARAMETER_NAMES} "
                )

        if isinstance(pruning_fn, str):
            pruning_fn = pruning_fn.lower()
            if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
                raise MisconfigurationException(
                    f"The provided pruning_fn {pruning_fn} isn't available with "
                    f"PyTorch build-in {_PYTORCH_PRUNING_FUNCTIONS.keys()} "
                )
            if "unstructured" not in pruning_fn:
                # Structured pruning always needs a dimension; `ln_structured`
                # additionally needs the norm order `n`.
                if pruning_dim is None:
                    raise MisconfigurationException(
                        "When requesting `structured` pruning, the `pruning_dim` should be provided."
                    )
                if pruning_fn == "ln_structured":
                    if pruning_norm is None:
                        raise MisconfigurationException(
                            "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
                        )
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim, n=pruning_norm)
                else:
                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim)
            else:
                pruning_fn = self._create_pruning_fn(pruning_fn)
        else:
            # BUGFIX: the previous check inspected only ``__bases__[0]``, which
            # rejected indirect subclasses of ``BasePruningMethod`` and crashed
            # with ``TypeError`` for non-class values (including the ``None``
            # default). ``inspect.isclass`` + ``issubclass`` covers both.
            if not (inspect.isclass(pruning_fn) and issubclass(pruning_fn, pytorch_prune.BasePruningMethod)):
                raise MisconfigurationException(
                    f'pruning_fn is expected to be the str in {_PYTORCH_PRUNING_FUNCTIONS.keys()} '
                    f'or a `PyTorch BasePruningMethod`. Found: {pruning_fn}'
                )
            if not use_global_unstructured:
                raise MisconfigurationException(
                    '`PyTorch BasePruningMethod` is currently supported only for `use_global_unstructured=True`. ')

        if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
            raise MisconfigurationException(
                'Only "unstructured" PRUNING_TYPE supported for '
                f"the `pruning_method`. Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
            )

        self.pruning_fn = pruning_fn
        self.make_pruning_permanent = make_pruning_permanent

        # ``callable(amount)`` is the idiomatic runtime check equivalent to
        # ``isinstance(amount, Callable)``.
        if not (isinstance(amount, (int, float)) or callable(amount)):
            raise MisconfigurationException(
                "amount should be provided and be either an int, a float or Callable function."
            )
        self.amount = amount

    def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]]):
        """Override this hook to control which (module, parameter) pairs get pruned."""
        return parameters_to_prune

    def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs):
        """Resolve a pruning-function name into a usable pruning callable.

        If ``use_global_unstructured`` is True, ``pruning_fn`` is resolved into its
        associated PyTorch ``BasePruningMethod`` class and its keyword arguments
        are stored for later re-injection (see ``_resolve_global_kwargs``);
        otherwise it is resolved into its functional counterpart from
        ``torch.nn.utils.prune`` with the keyword arguments pre-bound.
        """
        if self._use_global_unstructured:
            pruning_method = _PYTORCH_PRUNING_METHOD[pruning_fn]
            self._global_kwargs = kwargs
            return pruning_method
        return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)

    @staticmethod
    def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
        # Bind the extra keyword arguments (e.g. ``dim``/``n``) once so that the
        # wrapped function only needs ``module``, ``name`` and ``amount`` later.
        return partial(pruning_fn, **kwargs)

    def _make_pruning_permanent(self):
        """Remove the pruning reparametrization pre-hooks and mask tensors."""
        for module, param_name in self._parameters_to_prune:
            pytorch_prune.remove(module, param_name)

    def _resolve_amount(self, current_epoch: int) -> Union[int, float]:
        """Return the pruning amount for this epoch, calling ``amount`` if it is a schedule."""
        return self.amount(current_epoch) if callable(self.amount) else self.amount

    def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
        """
        "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf) algorithm:

            1. Randomly initialize a neural network f(x; θ0) (where θ0 ~ Dθ).
            2. Train the network for j iterations, arriving at parameters θj.
            3. Prune p% of the parameters in θj, creating a mask m.
            4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m⊙θ0).

        This function is responsible for step 4.
        """
        trained = getattr(module, tensor_name)
        orig = getattr(orig_module, tensor_name)
        if trained is None or orig is None:
            return
        trained.data = orig.data.to(trained.device)

    def apply_lottery_ticket_hypothesis(self):
        """Reset every registered parameter to its initial (pre-training) value."""
        for (mod, tensor_name), (initial_mod, _) in zip(self._parameters_to_prune, self._initial_parameters_to_prune):
            self._restore_original_weights(mod, initial_mod, tensor_name)

    def _apply_local_pruning(self, amount: Union[int, float]):
        """Prune each registered (module, parameter) pair independently."""
        for module, param_name in self._parameters_to_prune:
            self.pruning_fn(module, name=param_name, amount=amount)

    def _resolve_global_kwargs(self, amount: Union[int, float]) -> dict:
        """Collect the stored keyword arguments accepted by ``self.pruning_fn``'s signature."""
        self._global_kwargs["amount"] = amount
        params = inspect.signature(self.pruning_fn).parameters
        return {
            p_name: self._global_kwargs[p_name]
            for p_name in params
            if p_name != "self" and self._global_kwargs.get(p_name) is not None
        }

    def _apply_global_pruning(self, amount: Union[int, float]):
        """Prune across all registered parameters at once with global unstructured pruning."""
        pytorch_prune.global_unstructured(
            self._parameters_to_prune,
            pruning_method=self.pruning_fn,
            **self._resolve_global_kwargs(amount)
        )

    def apply_pruning(self, trainer: 'pl.Trainer', pl_module: LightningModule):
        """Run one pruning round, then optionally reset the surviving weights."""
        amount = self._resolve_amount(trainer.current_epoch)

        # The user can control the pruning frequency through an `amount`
        # schedule returning 0/None for epochs that should be skipped.
        if amount is None or amount == 0:
            return

        if self._use_global_unstructured:
            self._apply_global_pruning(amount)
        else:
            self._apply_local_pruning(amount)

        if self._use_lottery_ticket_hypothesis:
            self.apply_lottery_ticket_hypothesis()

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        """Resolve, validate and filter the parameters to prune before training starts."""
        parameters_to_prune = self.sanitize_parameters_to_prune(
            pl_module, self._parameters_to_prune, parameters=self._parameter_names)
        self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)

        if self._use_lottery_ticket_hypothesis:
            # Make a copy of the original weights so they can be restored
            # after each pruning round.
            self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune]

    def on_epoch_end(self, trainer, pl_module):
        self.apply_pruning(trainer, pl_module)

        if self.make_pruning_permanent:
            # NOTE(review): this runs at the end of *every* epoch, so the
            # pruning hooks are removed right after each round — confirm this
            # is intended rather than a one-time cleanup at fit end.
            self._make_pruning_permanent()

    @staticmethod
    def _sanitize_parameters_to_prune(p) -> bool:
        """
        Check the provided element is a pair ``(nn.Module, str)``.

        Example::

            parameters_to_prune = [
                (model.mlp_1, "weight"),
                (model.mlp_2, "weight")
            ]
        """
        return len(p) == 2 and isinstance(p[0], nn.Module) and isinstance(p[1], str)

    @staticmethod
    def sanitize_parameters_to_prune(
        pl_module: LightningModule,
        parameters_to_prune: Optional[List[Tuple[nn.Module, str]]],
        parameters: Tuple[str, ...] = ("weight",),
    ) -> List[Tuple[nn.Module, str]]:
        """
        Check the provided ``parameters_to_prune`` and ``parameters``.

        If ``parameters_to_prune`` is None, it is generated from all matching
        parameters of the model's modules.

        Raises:
            MisconfigurationException: if ``parameters_to_prune`` is malformed or
                references modules/parameters that do not exist on the model.
        """
        is_parameters_to_prune_none = parameters_to_prune is None
        # LightningModule itself and bare containers are skipped; only leaf-ish
        # modules carrying the parameters are considered.
        current_modules = [
            m for m in pl_module.modules()
            if not isinstance(m, (LightningModule, ModuleDict, ModuleList))
        ]

        if is_parameters_to_prune_none:
            parameters_to_prune = [
                (m, p)
                for p in parameters
                for m in current_modules
                if getattr(m, p, None) is not None
            ]

        if isinstance(parameters_to_prune, (tuple, list)) \
                and len(parameters_to_prune) > 0 and not is_parameters_to_prune_none:
            if all(
                isinstance(p, (list, tuple)) and ModelPruning._sanitize_parameters_to_prune(p)
                for p in parameters_to_prune
            ):
                missing_modules = []
                missing_parameters = []
                for module, param_name in parameters_to_prune:
                    if module not in current_modules:
                        missing_modules.append(module)
                        continue
                    # BUGFIX: use a ``getattr`` default so a missing attribute is
                    # reported instead of raising AttributeError, and record the
                    # parameter *name* rather than its (None) value.
                    if getattr(module, param_name, None) is None:
                        missing_parameters.append(param_name)

                if missing_modules or missing_parameters:
                    raise MisconfigurationException(
                        "The provided parameters_to_prune doesn't exist in the model."
                        f" Found mismatching modules: {missing_modules} and missing_parameters: {missing_parameters}"
                    )
            else:
                raise MisconfigurationException(
                    "The provided parameters_to_prune should either be list of tuple "
                    "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")
        elif not isinstance(parameters_to_prune, (list, tuple)):
            raise MisconfigurationException(
                "The provided parameters_to_prune should either be list of tuple "
                "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None")

        return parameters_to_prune