2021-01-27 06:00:42 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-05-25 22:57:56 +00:00
|
|
|
import re
|
2021-02-10 15:03:23 +00:00
|
|
|
from collections import OrderedDict
|
|
|
|
from logging import INFO
|
2021-03-09 11:27:15 +00:00
|
|
|
from typing import Union
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2021-02-10 15:03:23 +00:00
|
|
|
import torch.nn.utils.prune as pytorch_prune
|
2021-01-27 06:00:42 +00:00
|
|
|
from torch import nn
|
2021-02-10 15:03:23 +00:00
|
|
|
from torch.nn import Sequential
|
2021-01-27 06:00:42 +00:00
|
|
|
|
2021-05-25 22:57:56 +00:00
|
|
|
from pytorch_lightning import Trainer
|
2021-03-04 23:10:52 +00:00
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
|
2021-01-27 06:00:42 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-07-20 18:31:49 +00:00
|
|
|
from tests.helpers.boring_model import BoringModel
|
2021-03-02 09:36:01 +00:00
|
|
|
from tests.helpers.runif import RunIf
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
2021-02-10 15:03:23 +00:00
|
|
|
class TestModel(BoringModel):
    """``BoringModel`` with a named three-layer MLP head so tests can address individual layers."""

    def __init__(self):
        super().__init__()
        named_layers = [
            ("mlp_1", nn.Linear(32, 32)),
            ("mlp_2", nn.Linear(32, 32, bias=False)),
            ("mlp_3", nn.Linear(32, 2)),
        ]
        # The names become attributes of the container (``self.layer.mlp_1``, ...), which the tests use directly.
        self.layer = Sequential(OrderedDict(named_layers))

    def training_step(self, batch, batch_idx):
        """Log a strictly decreasing ``"test"`` metric, then delegate to ``BoringModel.training_step``."""
        # Used as the ``monitor`` of a ``ModelCheckpoint`` elsewhere in this file.
        self.log("test", -batch_idx)
        return super().training_step(batch, batch_idx)
|
|
|
|
|
2021-01-27 06:00:42 +00:00
|
|
|
|
2021-02-10 15:03:23 +00:00
|
|
|
class TestPruningMethod(pytorch_prune.BasePruningMethod):
    """Custom unstructured pruning method that zeroes every other element in flattened order."""

    PRUNING_TYPE = "unstructured"

    def compute_mask(self, _, default_mask):
        """Return a copy of ``default_mask`` with flattened positions 0, 2, 4, ... set to 0.

        The tensor being pruned (first argument) is ignored, and ``default_mask`` itself is left untouched.
        """
        new_mask = default_mask.clone()
        flat_view = new_mask.view(-1)
        # Prune every other entry in a tensor
        flat_view[::2] = 0
        return new_mask

    @classmethod
    def apply(cls, module, name, amount):
        """Forward to ``BasePruningMethod.apply``, passing ``amount`` by keyword."""
        return super().apply(module, name, amount=amount)
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
def train_with_pruning_callback(
    tmpdir,
    parameters_to_prune=False,
    use_global_unstructured=False,
    pruning_fn="l1_unstructured",
    use_lottery_ticket_hypothesis=False,
    strategy=None,
    accelerator="cpu",
    devices=1,
):
    """Fit and test ``TestModel`` with a ``ModelPruning`` callback configured from the arguments.

    Args:
        tmpdir: root directory for the ``Trainer``.
        parameters_to_prune: if truthy, prune exactly ``mlp_1.weight`` and ``mlp_2.weight``; otherwise pass
            ``parameter_names`` and let the callback pick the parameters.
        use_global_unstructured: forwarded to ``ModelPruning``.
        pruning_fn: a pruning function name (``str``) or a ``BasePruningMethod`` subclass.
        use_lottery_ticket_hypothesis: forwarded to ``ModelPruning``.
        strategy: forwarded to the ``Trainer``.
        accelerator: forwarded to the ``Trainer``.
        devices: forwarded to the ``Trainer``.

    Combinations that ``ModelPruning`` rejects are asserted to raise ``MisconfigurationException``, and the
    function then returns without training.
    """
    model = TestModel()

    # Weights are random. None is 0
    assert torch.all(model.layer.mlp_2.weight != 0)

    is_structured = isinstance(pruning_fn, str) and pruning_fn.endswith("_structured")

    pruning_kwargs = dict(
        pruning_fn=pruning_fn,
        amount=0.3,
        use_global_unstructured=use_global_unstructured,
        use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesis,
        verbose=1,
    )
    if not parameters_to_prune:
        # Structured variants are only asked to prune weights; unstructured ones also prune biases.
        pruning_kwargs["parameter_names"] = ["weight"] if is_structured else ["weight", "bias"]
    else:
        pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
    if is_structured:
        pruning_kwargs["pruning_dim"] = 0
    if pruning_fn == "ln_structured":
        pruning_kwargs["pruning_norm"] = 1

    # Misconfiguration checks
    expected_error = None
    if is_structured and use_global_unstructured:
        expected_error = "is supported with `use_global_unstructured=True`"
    elif ModelPruning._is_pruning_method(pruning_fn) and not use_global_unstructured:
        expected_error = "currently only supported with"
    if expected_error is not None:
        with pytest.raises(MisconfigurationException, match=expected_error):
            ModelPruning(**pruning_kwargs)
        return

    pruning = ModelPruning(**pruning_kwargs)
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=10,
        strategy=strategy,
        accelerator=accelerator,
        devices=devices,
        callbacks=pruning,
    )
    trainer.fit(model)
    trainer.test(model)

    # NOTE(review): presumably skipped under a strategy because the local ``model`` may not hold the trained
    # weights (e.g. with ``ddp_spawn``) — confirm.
    if strategy:
        return
    # Check some have been pruned
    assert torch.any(model.layer.mlp_2.weight == 0)
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
2021-02-10 15:03:23 +00:00
|
|
|
def test_pruning_misconfiguration():
    """Each invalid ``ModelPruning`` configuration raises ``MisconfigurationException`` with a specific message."""
    invalid_cases = [
        (dict(pruning_fn="l1_unstructured", parameter_names=["chocolate"]), r"chocolate isn't in \('weight', 'bias'\)"),
        (dict(pruning_fn={}), r"expected to be a str in \["),
        (dict(pruning_fn="random_structured"), "should be provided"),
        (dict(pruning_fn="l1_unstructured", verbose=3), r"must be any of \(0, 1, 2\)"),
        (dict(pruning_fn="ln_structured", pruning_dim=0), "requesting `ln_structured` pruning, the `pruning_norm`"),
    ]
    for bad_kwargs, message in invalid_cases:
        with pytest.raises(MisconfigurationException, match=message):
            ModelPruning(**bad_kwargs)
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("parameters_to_prune", [False, True])
@pytest.mark.parametrize("use_global_unstructured", [False, True])
@pytest.mark.parametrize(
    "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod]
)
@pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True])
def test_pruning_callback(
    tmpdir,
    use_global_unstructured: bool,
    parameters_to_prune: bool,
    pruning_fn: Union[str, pytorch_prune.BasePruningMethod],
    use_lottery_ticket_hypothesis: bool,
):
    """Single-process pruning across the full grid of pruning options (invalid combinations must raise)."""
    options = dict(
        pruning_fn=pruning_fn,
        parameters_to_prune=parameters_to_prune,
        use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesis,
        use_global_unstructured=use_global_unstructured,
    )
    train_with_pruning_callback(tmpdir, **options)
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
2022-03-27 21:31:20 +00:00
|
|
|
@RunIf(min_gpus=2, standalone=True)
@pytest.mark.parametrize("parameters_to_prune", (False, True))
@pytest.mark.parametrize("use_global_unstructured", (False, True))
def test_pruning_callback_ddp(tmpdir, parameters_to_prune, use_global_unstructured):
    """Default (``l1_unstructured``) pruning under the ``ddp`` strategy on two GPUs."""
    distributed = dict(strategy="ddp", accelerator="gpu", devices=2)
    train_with_pruning_callback(
        tmpdir,
        parameters_to_prune=parameters_to_prune,
        use_global_unstructured=use_global_unstructured,
        **distributed,
    )
|
|
|
|
|
|
|
|
|
2021-03-02 09:36:01 +00:00
|
|
|
@RunIf(min_gpus=2, skip_windows=True)
def test_pruning_callback_ddp_spawn(tmpdir):
    """Global unstructured pruning under ``ddp_spawn`` on two GPUs."""
    train_with_pruning_callback(
        tmpdir,
        use_global_unstructured=True,
        strategy="ddp_spawn",
        accelerator="gpu",
        devices=2,
    )
|
2021-01-27 06:00:42 +00:00
|
|
|
|
|
|
|
|
2022-03-27 21:31:20 +00:00
|
|
|
@RunIf(skip_windows=True)
def test_pruning_callback_ddp_cpu(tmpdir):
    """Pruning of explicitly listed parameters under ``ddp_spawn`` with two CPU processes."""
    train_with_pruning_callback(
        tmpdir,
        parameters_to_prune=True,
        strategy="ddp_spawn",
        accelerator="cpu",
        devices=2,
    )
|
2021-02-10 15:03:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("resample_parameters", (False, True))
def test_pruning_lth_callable(tmpdir, resample_parameters: bool):
    """A callable ``use_lottery_ticket_hypothesis`` decides when LTH runs, and LTH restores or resamples weights."""
    model = TestModel()

    class ModelPruningTestCallback(ModelPruning):
        lth_calls = 0

        def apply_lottery_ticket_hypothesis(self):
            super().apply_lottery_ticket_hypothesis()
            self.lth_calls += 1

            # Each pruned parameter should equal its saved original copy, unless resampling was requested.
            for layer_record in self._original_layers.values():
                saved_module, index_name_pairs = layer_record["data"], layer_record["names"]
                for idx, param_name in index_name_pairs:
                    live_module, live_name = self._parameters_to_prune[idx]
                    assert param_name == live_name
                    current = getattr(live_module, param_name).data
                    original = getattr(saved_module, param_name).data
                    unchanged = torch.allclose(current, original)
                    if self._resample_parameters:
                        assert not unchanged
                    else:
                        assert unchanged

    pruning = ModelPruningTestCallback(
        "l1_unstructured",
        use_lottery_ticket_hypothesis=lambda e: bool(e % 2),
        resample_parameters=resample_parameters,
    )
    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=5,
        callbacks=pruning,
    )
    trainer.fit(model)

    # The callable is truthy only for odd arguments (presumably the epoch index), i.e. half the epochs rounded down.
    assert pruning.lth_calls == trainer.max_epochs // 2
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("make_pruning_permanent", (False, True))
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool):
    """Two ``ModelPruning`` callbacks that alternate between epochs log their sparsity stats in order."""
    model = TestModel()
    shared_kwargs = dict(
        parameters_to_prune=[(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")],
        verbose=2,
        make_pruning_permanent=make_pruning_permanent,
    )
    # ``p1`` fires when its argument is even, ``p2`` when it is odd.
    p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **shared_kwargs)
    p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **shared_kwargs)

    trainer = Trainer(
        default_root_dir=tmpdir,
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        limit_train_batches=10,
        limit_val_batches=2,
        max_epochs=3,
        callbacks=[p1, p2],
    )
    with caplog.at_level(INFO):
        trainer.fit(model)

    stripped = (message.strip() for message in caplog.messages)
    applied = [line for line in stripped if line.startswith("Applied")]
    percentage = r"\(\d+(?:\.\d+)?%\)"
    expected = [
        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: 0 \(0.00%\) -> \d+ {percentage}",  # noqa: E501
        rf"Applied `RandomUnstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `RandomUnstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.25. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured`. Pruned: \d+\/1122 {percentage} -> \d+\/1122 {percentage}",
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=32, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
        rf"Applied `L1Unstructured` to `Linear\(in_features=32, out_features=2, bias=True\).weight` with amount=0.5. Pruned: \d+ {percentage} -> \d+ {percentage}",  # noqa: E501
    ]
    patterns = list(map(re.compile, expected))
    # NOTE(review): ``zip`` truncates, so this passes vacuously if fewer messages were captured than expected.
    assert all(pattern.match(line) for line, pattern in zip(applied, patterns))

    filepath = str(tmpdir / "foo.ckpt")
    trainer.save_checkpoint(filepath)

    # NOTE(review): ``load_from_checkpoint`` returns a new model that is discarded here. This only checks that
    # loading succeeds; the assertion below inspects the trained ``model``, not the loaded one.
    model.load_from_checkpoint(filepath, strict=False)
    has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
    if make_pruning_permanent:
        assert not has_pruning
    else:
        assert has_pruning
|
2021-03-03 12:29:58 +00:00
|
|
|
|
|
|
|
|
2021-07-13 14:47:59 +00:00
|
|
|
@pytest.mark.parametrize("prune_on_train_epoch_end", (False, True))
@pytest.mark.parametrize("save_on_train_epoch_end", (False, True))
def test_permanent_when_model_is_saved_multiple_times(
    tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end
):
    """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not
    the trained model if we want to continue with the same pruning buffers."""
    if prune_on_train_epoch_end and save_on_train_epoch_end:
        pytest.xfail(
            "Pruning sets the `grad_fn` of the parameters so we can't save"
            " right after as pruning has not been made permanent"
        )

    class TestPruning(ModelPruning):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            had_buffers = hasattr(pl_module.layer.mlp_3, "weight_orig")
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            # The saved state dict must not contain the pruning reparametrization...
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
            # ...while the live module keeps its buffers so training can continue with them.
            if had_buffers:
                assert hasattr(pl_module.layer.mlp_3, "weight_orig")

    model = TestModel()
    pruning_callback = TestPruning(
        "random_unstructured",
        parameters_to_prune=[(model.layer.mlp_3, "weight")],
        verbose=1,
        make_pruning_permanent=True,
        prune_on_train_epoch_end=prune_on_train_epoch_end,
    )
    ckpt_callback = ModelCheckpoint(
        monitor="test", save_top_k=2, save_last=True, save_on_train_epoch_end=save_on_train_epoch_end
    )
    trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, enable_progress_bar=False)
    with caplog.at_level(INFO):
        trainer.fit(model)

    actual = [m.strip() for m in caplog.messages]
    actual = [m for m in actual if m.startswith("Applied")]
    percentage = r"\(\d+(?:\.\d+)?%\)"
    # One global sparsity message is expected per pruning round (three rounds).
    expected = [rf"Applied `RandomUnstructured`. Pruned: \d+\/66 {percentage} -> \d+\/66 {percentage}"] * 3
    expected = [re.compile(s) for s in expected]
    assert all(regex.match(s) for s, regex in zip(actual, expected))

    # removed on_train_end
    assert not hasattr(model.layer.mlp_3, "weight_orig")

    # ``load_from_checkpoint`` is a classmethod that returns a *new* model, so inspect that result rather than the
    # already-trained ``model``. Assuming Lightning's default strict loading, a leftover
    # ``weight_orig``/``weight_mask`` entry in the checkpoint would also make the load itself fail.
    checkpoint_paths = (
        trainer.checkpoint_callback.kth_best_model_path,
        trainer.checkpoint_callback.last_model_path,
    )
    for ckpt_path in checkpoint_paths:
        loaded = TestModel.load_from_checkpoint(ckpt_path)
        assert not hasattr(loaded.layer.mlp_3, "weight_orig")
|