[ModelPruning] Add missing attribute with use_global_unstructured=False and verbose (#6045)

Carlos Mocholí 2021-02-18 11:40:34 +01:00 committed by GitHub
parent 6de8dca31c
commit 38ad9e0764
2 changed files with 17 additions and 8 deletions
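
For reference, a minimal usage sketch of the configuration this commit fixes; `model` is a placeholder `LightningModule` and the callback kwargs simply mirror the updated test, so treat it as an illustration rather than part of the change:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

# Local (non-global) unstructured pruning with sparsity logging enabled.
# use_global_unstructured=False wraps the pruning function in functools.partial,
# whose missing __name__ previously broke the verbose logging; verbose=1
# exercises the now rank-zero-only _log_sparsity_stats path.
pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.3,
    use_global_unstructured=False,
    use_lottery_ticket_hypothesis=False,
    verbose=1,
)
trainer = Trainer(callbacks=[pruning], max_epochs=2)
trainer.fit(model)  # `model`: any LightningModule (placeholder)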


@@ -24,9 +24,10 @@ import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
 
+from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _PYTORCH_PRUNING_FUNCTIONS = {
@@ -152,6 +153,7 @@ class ModelPruning(Callback):
         self._parameter_names = parameter_names or self.PARAMETER_NAMES
         self._global_kwargs = {}
         self._original_layers = None
+        self._pruning_fn_name = None
 
         for name in self._parameter_names:
             if name not in self.PARAMETER_NAMES:
@@ -231,8 +233,14 @@ class ModelPruning(Callback):
         if self._use_global_unstructured:
             pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
             self._global_kwargs = kwargs
+        else:
+            pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
+        # save the function __name__ now because partial does not include it
+        # and there are issues setting the attribute manually in ddp.
+        self._pruning_fn_name = pruning_fn.__name__
+        if self._use_global_unstructured:
             return pruning_fn
-        return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)
+        return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)
 
     @staticmethod
     def _wrap_pruning_fn(pruning_fn, **kwargs):
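
The comments added above are the crux of the fix: `functools.partial` objects do not carry a `__name__`, so the name has to be cached before wrapping (and setting the attribute on the partial manually is problematic under DDP). A standalone illustration, using one of torch's built-in pruning functions as an example:

import functools
import torch.nn.utils.prune as pytorch_prune

# partial() drops __name__, which is why _pruning_fn_name is cached up front
wrapped = functools.partial(pytorch_prune.l1_unstructured, amount=0.5)
print(hasattr(wrapped, "__name__"))            # False
print(pytorch_prune.l1_unstructured.__name__)  # "l1_unstructured"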
@@ -321,15 +329,15 @@ class ModelPruning(Callback):
             curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
             self._log_sparsity_stats(prev_stats, curr_stats, amount=amount)
 
+    @rank_zero_only
     def _log_sparsity_stats(
         self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
     ):
         total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
-        pruning_fn_name = self.pruning_fn.__name__
-        rank_zero_info(
-            f"Applied `{pruning_fn_name}`. Pruned:"
+        log.info(
+            f"Applied `{self._pruning_fn_name}`. Pruned:"
             f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->"
             f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})"
         )
@@ -337,8 +345,8 @@ class ModelPruning(Callback):
             for i, (module, name) in enumerate(self._parameters_to_prune):
                 prev_mask_zeros, prev_mask_size = prev[i]
                 curr_mask_zeros, curr_mask_size = curr[i]
-                rank_zero_info(
-                    f"Applied `{pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
+                log.info(
+                    f"Applied `{self._pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
                     f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->"
                     f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
                 )
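
Note on the logging change in the two hunks above: decorating `_log_sparsity_stats` with `@rank_zero_only` and switching to `log.info` means non-zero ranks skip the whole method body (the totals computation included), whereas the previous `rank_zero_info` calls only gated each message. A minimal sketch of the decorator's effect, assuming the usual pytorch_lightning utility:

from pytorch_lightning.utilities import rank_zero_only

@rank_zero_only
def report(msg: str) -> None:
    # Runs only on the process with global rank 0; a no-op everywhere else.
    print(msg)

report("only rank 0 prints this")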


@@ -76,7 +76,8 @@ def train_with_pruning_callback(
         "pruning_fn": pruning_fn,
         "amount": 0.3,
         "use_global_unstructured": use_global_unstructured,
-        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis
+        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
+        "verbose": 1,
     }
     if parameters_to_prune:
         pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]