diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index d3d280dbaa..ee130a700a 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -24,9 +24,10 @@ import torch
 import torch.nn.utils.prune as pytorch_prune
 from torch import nn
 
+from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _PYTORCH_PRUNING_FUNCTIONS = {
@@ -152,6 +153,7 @@ class ModelPruning(Callback):
         self._parameter_names = parameter_names or self.PARAMETER_NAMES
         self._global_kwargs = {}
         self._original_layers = None
+        self._pruning_fn_name = None
 
         for name in self._parameter_names:
             if name not in self.PARAMETER_NAMES:
@@ -231,8 +233,14 @@ class ModelPruning(Callback):
         if self._use_global_unstructured:
             pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
             self._global_kwargs = kwargs
+        else:
+            pruning_fn = _PYTORCH_PRUNING_FUNCTIONS[pruning_fn]
+        # save the function __name__ now because partial does not include it
+        # and there are issues setting the attribute manually in ddp.
+        self._pruning_fn_name = pruning_fn.__name__
+        if self._use_global_unstructured:
             return pruning_fn
-        return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)
+        return ModelPruning._wrap_pruning_fn(pruning_fn, **kwargs)
 
     @staticmethod
     def _wrap_pruning_fn(pruning_fn, **kwargs):
@@ -321,15 +329,15 @@ class ModelPruning(Callback):
             curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune]
             self._log_sparsity_stats(prev_stats, curr_stats, amount=amount)
 
+    @rank_zero_only
     def _log_sparsity_stats(
         self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0
     ):
         total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters())
         prev_total_zeros = sum(zeros for zeros, _ in prev)
         curr_total_zeros = sum(zeros for zeros, _ in curr)
-        pruning_fn_name = self.pruning_fn.__name__
-        rank_zero_info(
-            f"Applied `{pruning_fn_name}`. Pruned:"
+        log.info(
+            f"Applied `{self._pruning_fn_name}`. Pruned:"
             f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->"
             f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})"
         )
@@ -337,8 +345,8 @@ class ModelPruning(Callback):
         for i, (module, name) in enumerate(self._parameters_to_prune):
             prev_mask_zeros, prev_mask_size = prev[i]
             curr_mask_zeros, curr_mask_size = curr[i]
-            rank_zero_info(
-                f"Applied `{pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
+            log.info(
+                f"Applied `{self._pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"
                 f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->"
                 f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
             )
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 34fbd57203..62b0d3a8f3 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -76,7 +76,8 @@ def train_with_pruning_callback(
         "pruning_fn": pruning_fn,
         "amount": 0.3,
         "use_global_unstructured": use_global_unstructured,
-        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis
+        "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis,
+        "verbose": 1,
     }
     if parameters_to_prune:
         pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]