[ModelPruning] Add missing attribute with use_global_unstructured=False and verbose (#6045)
This commit is contained in:
parent
6de8dca31c
commit
38ad9e0764
|
@ -24,9 +24,10 @@ import torch
|
|||
import torch.nn.utils.prune as pytorch_prune
|
||||
from torch import nn
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
_PYTORCH_PRUNING_FUNCTIONS = {
|
||||
|
@ -152,6 +153,7 @@ class ModelPruning(Callback):
|
|||
self._parameter_names = parameter_names or self.PARAMETER_NAMES
|
||||
self._global_kwargs = {}
|
||||
self._original_layers = None
|
||||
self._pruning_fn_name = None
|
||||
|
||||
for name in self._parameter_names:
|
||||
if name not in self.PARAMETER_NAMES:
|
||||
|
@ -231,8 +233,14 @@ class ModelPruning(Callback):
|
|||
if self._use_global_unstructured:
|
||||
pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
|
||||
self._global_kwargs = kwargs
|
||||
else:
|
||||
pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
|
||||
# save the function __name__ now because partial does not include it
|
||||
# and there are issues setting the attribute manually in ddp.
|
||||
self._pruning_fn_name = pruning_fn.__name__
|
||||
if self._use_global_unstructured:
|
||||
return pruning_fn
|
||||
return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)
|
||||
return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_pruning_fn(pruning_fn, **kwargs):
|
||||
|
@ -321,15 +329,15 @@ class ModelPruning(Callback):
|
|||
curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
|
||||
self._log_sparsity_stats(prev_stats, curr_stats, amount=amount)
|
||||
|
||||
@rank_zero_only
|
||||
def _log_sparsity_stats(
|
||||
self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
|
||||
):
|
||||
total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
|
||||
prev_total_zeros = sum(zeros for zeros, _ in prev)
|
||||
curr_total_zeros = sum(zeros for zeros, _ in curr)
|
||||
pruning_fn_name = self.pruning_fn.__name__
|
||||
rank_zero_info(
|
||||
f"Applied `{pruning_fn_name}`. Pruned:"
|
||||
log.info(
|
||||
f"Applied `{self._pruning_fn_name}`. Pruned:"
|
||||
f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->"
|
||||
f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})"
|
||||
)
|
||||
|
@ -337,8 +345,8 @@ class ModelPruning(Callback):
|
|||
for i, (module, name) in enumerate(self._parameters_to_prune):
|
||||
prev_mask_zeros, prev_mask_size = prev[i]
|
||||
curr_mask_zeros, curr_mask_size = curr[i]
|
||||
rank_zero_info(
|
||||
f"Applied `{pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
|
||||
log.info(
|
||||
f"Applied `{self._pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
|
||||
f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->"
|
||||
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
|
||||
)
|
||||
|
|
|
@ -76,7 +76,8 @@ def train_with_pruning_callback(
|
|||
"pruning_fn": pruning_fn,
|
||||
"amount": 0.3,
|
||||
"use_global_unstructured": use_global_unstructured,
|
||||
"use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis
|
||||
"use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
|
||||
"verbose": 1,
|
||||
}
|
||||
if parameters_to_prune:
|
||||
pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
|
||||
|
|
Loading…
Reference in New Issue