# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform

import numpy as np
import pytest
import torch
from torch import nn

from pytorch_lightning import Trainer
from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel

if _PYTORCH_PRUNE_AVAILABLE:
    import torch.nn.utils.prune as pytorch_prune

    from pytorch_lightning.callbacks import ModelPruning

class PruningModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.layer = nn.ModuleDict()

        self.layer["mlp_1"] = nn.Linear(32, 32)
        self.layer["mlp_2"] = nn.Linear(32, 32)
        self.layer["mlp_3"] = nn.Linear(32, 32)
        self.layer["mlp_4"] = nn.Linear(32, 32)
        self.layer["mlp_5"] = nn.Linear(32, 2)

    def forward(self, x):
        m = self.layer
        x = m["mlp_1"](x)
        x = m["mlp_2"](x)
        x = m["mlp_3"](x)
        x = m["mlp_4"](x)
        return m["mlp_5"](x)

    def training_step(self, batch, batch_idx):
        output = self.forward(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}


def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune,
    use_global_unstructured,
    accelerator=None,
    gpus=None,
    num_processes=None,
    use_custom_pruning_fn=False,
):
    # Skipped as currently not supported.
    # Todo: add support for custom pruning_fn function.
    if not use_global_unstructured and use_custom_pruning_fn:
        return

    model = PruningModel()
    model.validation_step = None
    model.test_step = None

    if parameters_to_prune:
        parameters_to_prune = [
            (model.layer["mlp_1"], "weight"),
            (model.layer["mlp_2"], "weight"),
        ]
    else:
        parameters_to_prune = None

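    # sanity check: nothing has been pruned yet, so no weight should already be exactly zero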
    assert torch.sum(model.layer["mlp_2"].weight == 0) == 0

    class TestPruningMethod(pytorch_prune.BasePruningMethod):
        """Prune every other entry in a tensor."""

        PRUNING_TYPE = 'unstructured'

        def compute_mask(self, t, default_mask):
            mask = default_mask.clone()
            mask.view(-1)[::2] = 0
            return mask

        @classmethod
        def apply(cls, module, name, amount):
            r"""Adds the forward pre-hook that enables pruning on the fly and
            the reparametrization of a tensor in terms of the original tensor
            and the pruning mask.

            Args:
                module (nn.Module): module containing the tensor to prune
                name (str): parameter name within ``module`` on which pruning
                    will act.
                amount (int or float): quantity of parameters to prune.
                    If ``float``, should be between 0.0 and 1.0 and represent the
                    fraction of parameters to prune. If ``int``, it represents the
                    absolute number of parameters to prune.
            """
            return super(TestPruningMethod, cls).apply(module, name, amount=amount)

    custom_pruning_fn = TestPruningMethod

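    # For reference only (a sketch, not executed here): a custom method like this can also
    # be used directly through the ``torch.nn.utils.prune`` API, e.g.
    #
    #     TestPruningMethod.apply(model.layer["mlp_1"], "weight", amount=0.5)
    #     # registers ``weight_orig``/``weight_mask`` and a forward pre-hook on the module
    #     pytorch_prune.remove(model.layer["mlp_1"], "weight")
    #     # makes the pruning permanent by re-assigning the masked tensor to ``weight``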
    pruning_funcs_structured = [
        "ln_structured",
        "random_structured",
    ]

    pruning_funcs_unstructured = [
        "l1_unstructured",
        "random_unstructured",
    ]

    if use_global_unstructured:
        pruning_list = pruning_funcs_unstructured
    else:
        pruning_list = pruning_funcs_unstructured + pruning_funcs_structured

    rand_idx = np.random.randint(len(pruning_list))
    pruning_fn = pruning_list[rand_idx]

    model_pruning_args = {
        "pruning_fn": custom_pruning_fn if use_custom_pruning_fn else pruning_fn,
        "parameters_to_prune": parameters_to_prune,
        "amount": 0.3,
        "use_global_unstructured": use_global_unstructured,
    }

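    # structured pruning removes entire slices along a dimension, so it needs ``pruning_dim``;
    # ``ln_structured`` additionally needs ``pruning_norm`` (these correspond to the ``dim``
    # and ``n`` arguments of the underlying ``torch.nn.utils.prune`` functions)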
if "unstructured" not in pruning_fn:
|
|
model_pruning_args["pruning_dim"] = 0
|
|
|
|
if pruning_fn == "ln_structured":
|
|
model_pruning_args["pruning_norm"] = 1
|
|
|
|
trainer = Trainer(
|
|
default_root_dir=tmpdir,
|
|
limit_train_batches=10,
|
|
limit_val_batches=2,
|
|
max_epochs=10,
|
|
accelerator=accelerator,
|
|
gpus=gpus,
|
|
num_processes=num_processes,
|
|
callbacks=ModelPruning(**model_pruning_args)
|
|
)
|
|
trainer.fit(model)
|
|
_ = trainer.test(model)
|
|
|
|
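    # the weight check below is only meaningful for single-process runs; with distributed
    # accelerators training happens in worker processes, so the local copy of the model
    # may not reflect the pruning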
    if accelerator is None:
        assert torch.sum(model.layer["mlp_2"].weight == 0) > 0


def test_with_pruning_callback_misconfiguration(tmpdir):
    model_pruning_args = {
        "parameter_names": ["chocolat"],
    }

    with pytest.raises(MisconfigurationException, match='provided parameter_names'):
        _ = ModelPruning(**model_pruning_args)

    model_pruning_args = {
        "parameter_names": ["weight"],
        "pruning_fn": model_pruning_args,
    }

    with pytest.raises(MisconfigurationException, match='pruning_fn is expected to be the str in'):
        _ = ModelPruning(**model_pruning_args)

    model_pruning_args = {
        "parameter_names": ["weight"],
        "pruning_fn": "random_structured",
    }

    with pytest.raises(MisconfigurationException, match='should be provided'):
        _ = ModelPruning(**model_pruning_args)

    model_pruning_args = {
        "parameter_names": ["weight"],
        "pruning_fn": "ln_structured",
        "pruning_dim": 0,
    }

    with pytest.raises(MisconfigurationException, match='requesting `ln_structured` pruning, the `pruning_norm`'):
        _ = ModelPruning(**model_pruning_args)


@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch pruning is needed for this test.")
@pytest.mark.parametrize("parameters_to_prune", [False, True])
@pytest.mark.parametrize("use_global_unstructured", [False, True])
@pytest.mark.parametrize("use_custom_pruning_fn", [False, True])
def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, use_custom_pruning_fn):
    train_with_pruning_callback(
        tmpdir, parameters_to_prune, use_global_unstructured,
        accelerator=None, gpus=None, num_processes=1, use_custom_pruning_fn=use_custom_pruning_fn)


@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch pruning is needed for this test.")
@pytest.mark.parametrize("parameters_to_prune", [False, True])
@pytest.mark.parametrize("use_global_unstructured", [False, True])
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_pruning_callback_ddp(tmpdir, use_global_unstructured, parameters_to_prune):
    train_with_pruning_callback(
        tmpdir, parameters_to_prune, use_global_unstructured,
        accelerator="ddp", gpus=2, num_processes=0)


@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch pruning is needed for this test.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
def test_pruning_callback_ddp_spawn(tmpdir):
    train_with_pruning_callback(tmpdir, False, True, accelerator="ddp_spawn", gpus=2, num_processes=None)


@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch pruning is needed for this test.")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
def test_pruning_callback_ddp_cpu(tmpdir):
    train_with_pruning_callback(tmpdir, True, False, accelerator="ddp_cpu", gpus=None, num_processes=2)