2020-10-13 11:18:07 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2020-05-28 02:44:46 +00:00
|
|
|
import pytest
|
2021-06-07 10:17:11 +00:00
|
|
|
import torch
|
2021-01-05 06:51:22 +00:00
|
|
|
from torch import optim
|
2020-05-28 02:44:46 +00:00
|
|
|
|
2021-02-08 10:52:02 +00:00
|
|
|
import tests.helpers.utils as tutils
|
2020-05-28 02:44:46 +00:00
|
|
|
from pytorch_lightning import Trainer
|
2020-09-03 18:17:15 +00:00
|
|
|
from pytorch_lightning.callbacks import LearningRateMonitor
|
2021-06-07 10:17:11 +00:00
|
|
|
from pytorch_lightning.callbacks.base import Callback
|
|
|
|
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning
|
2020-09-03 18:17:15 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-02-09 10:10:52 +00:00
|
|
|
from tests.helpers import BoringModel
|
2021-02-16 19:59:57 +00:00
|
|
|
from tests.helpers.datamodules import ClassifDataModule
|
|
|
|
from tests.helpers.simple_models import ClassificationModel
|
2020-05-28 02:44:46 +00:00
|
|
|
|
|
|
|
|
2020-09-03 18:17:15 +00:00
|
|
|
def test_lr_monitor_single_lr(tmpdir):
    """Check that the learning rate of a single scheduler is captured by ``LearningRateMonitor``.

    Momentum logging is not enabled here, so every recorded momentum value must be ``None``.
    """
    tutils.reset_seed()

    boring = BoringModel()
    monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[monitor],
    )
    trainer.fit(boring)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    assert monitor.lrs, "No learning rates logged"
    momentum_values = monitor.last_momentum_values.values()
    assert all(value is None for value in momentum_values), "Momentum should not be logged by default"
    assert len(monitor.lrs) == len(
        trainer.lr_schedulers
    ), "Number of learning rates logged does not match number of lr schedulers"
    logged_names = list(monitor.lrs)
    assert monitor.lr_sch_names == logged_names == ["lr-SGD"], "Names of learning rates not set correctly"
|
2020-05-28 02:44:46 +00:00
|
|
|
|
|
|
|
|
2021-07-26 11:37:35 +00:00
|
|
|
@pytest.mark.parametrize("opt", ["SGD", "Adam"])
def test_lr_monitor_single_lr_with_momentum(tmpdir, opt: str):
    """Test that learning rates and momentum are extracted and logged for single lr scheduler.

    SGD is configured with ``momentum`` and Adam with ``betas``; with ``log_momentum=True`` the
    monitor must record one non-``None`` momentum value per scheduler, keyed ``lr-<opt>-momentum``.
    """

    class LogMomentumModel(BoringModel):
        def __init__(self, opt):
            super().__init__()
            self.opt = opt

        def configure_optimizers(self):
            # Per-optimizer momentum settings: SGD uses `momentum`, Adam uses `betas`.
            momentum_kwargs = {"SGD": {"momentum": 0.9}, "Adam": {"betas": (0.9, 0.999)}}
            optimizer_cls = getattr(optim, self.opt)
            optimizer = optimizer_cls(self.parameters(), lr=1e-2, **momentum_kwargs[self.opt])
            scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=10_000)
            return [optimizer], [scheduler]

    model = LogMomentumModel(opt=opt)
    monitor = LearningRateMonitor(log_momentum=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=2,
        limit_train_batches=5,
        log_every_n_steps=1,
        callbacks=[monitor],
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    momentum_by_name = monitor.last_momentum_values
    assert all(value is not None for value in momentum_by_name.values()), "Expected momentum to be logged"
    assert len(momentum_by_name) == len(
        trainer.lr_schedulers
    ), "Number of momentum values logged does not match number of lr schedulers"
    expected_key = f"lr-{opt}-momentum"
    assert all(key == expected_key for key in momentum_by_name), "Names of momentum values not set correctly"
|
2021-01-05 06:51:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_log_momentum_no_momentum_optimizer(tmpdir):
    """Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True.

    ``ASGD`` is used as the momentum-less optimizer. Besides the ``RuntimeWarning``, the monitor is
    still expected to record one momentum entry per scheduler (key ``lr-ASGD-momentum``) whose
    value is 0.
    """

    class LogMomentumModel(BoringModel):
        def configure_optimizers(self):
            optimizer = optim.ASGD(self.parameters(), lr=1e-2)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

    model = LogMomentumModel()
    lr_monitor = LearningRateMonitor(log_momentum=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=2,
        limit_train_batches=5,
        log_every_n_steps=1,
        callbacks=[lr_monitor],
    )
    with pytest.warns(RuntimeWarning, match="optimizers do not have momentum."):
        trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # The value is logged but pinned to 0; the message reflects that, not mere presence.
    assert all(
        v == 0 for v in lr_monitor.last_momentum_values.values()
    ), "Expected momentum to be logged as 0 for an optimizer without momentum"
    assert len(lr_monitor.last_momentum_values) == len(
        trainer.lr_schedulers
    ), "Number of momentum values logged does not match number of lr schedulers"
    assert all(
        k == "lr-ASGD-momentum" for k in lr_monitor.last_momentum_values.keys()
    ), "Names of momentum values not set correctly"
|
2020-10-28 16:26:58 +00:00
|
|
|
|
|
|
|
|
2020-09-03 18:17:15 +00:00
|
|
|
def test_lr_monitor_no_lr_scheduler(tmpdir):
    """Check that ``LearningRateMonitor`` emits a ``RuntimeWarning`` when no lr scheduler is configured.

    Training itself must still complete normally.
    """
    tutils.reset_seed()

    class CustomBoringModel(BoringModel):
        def configure_optimizers(self):
            # A bare optimizer: nothing for the monitor to track.
            return optim.SGD(self.parameters(), lr=0.1)

    model = CustomBoringModel()
    monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[monitor],
    )

    with pytest.warns(RuntimeWarning, match="have no learning rate schedulers"):
        trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
|
2020-05-28 02:44:46 +00:00
|
|
|
|
|
|
|
|
2020-09-03 18:17:15 +00:00
|
|
|
def test_lr_monitor_no_logger(tmpdir):
    """Check that fitting with a ``LearningRateMonitor`` but ``logger=False`` raises ``MisconfigurationException``."""
    tutils.reset_seed()

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        callbacks=[LearningRateMonitor()],
        logger=False,
    )

    with pytest.raises(MisconfigurationException, match="`Trainer` that has no logger"):
        trainer.fit(model)
|
|
|
|
|
|
|
|
|
2021-07-26 11:37:35 +00:00
|
|
|
@pytest.mark.parametrize("logging_interval", ["step", "epoch"])
def test_lr_monitor_multi_lrs(tmpdir, logging_interval: str):
    """Test that learning rates are extracted and logged for multi lr schedulers.

    Two Adam optimizers, each with its own StepLR, must be logged as ``lr-Adam`` and ``lr-Adam-1``.
    The number of recorded values per scheduler depends on ``logging_interval``.
    """
    tutils.reset_seed()

    class CustomBoringModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            # Accept `optimizer_idx` (multi-optimizer signature) but delegate to the single-optimizer step.
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            optimizers = [optim.Adam(self.parameters(), lr=1e-2) for _ in range(2)]
            schedulers = [optim.lr_scheduler.StepLR(o, 1, gamma=0.1) for o in optimizers]
            return optimizers, schedulers

    model = CustomBoringModel()
    model.training_epoch_end = None

    log_every_n_steps = 2
    monitor = LearningRateMonitor(logging_interval=logging_interval)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        log_every_n_steps=log_every_n_steps,
        limit_train_batches=7,
        limit_val_batches=0.1,
        callbacks=[monitor],
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    assert monitor.lrs, "No learning rates logged"
    assert len(monitor.lrs) == len(
        trainer.lr_schedulers
    ), "Number of learning rates logged does not match number of lr schedulers"
    assert monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"], "Names of learning rates not set correctly"

    # Step-wise: one entry every `log_every_n_steps` global steps. Epoch-wise: one entry per epoch.
    expected_number_logged = {
        "step": trainer.global_step // log_every_n_steps,
        "epoch": trainer.max_epochs,
    }[logging_interval]
    assert all(
        len(history) == expected_number_logged for history in monitor.lrs.values()
    ), "Length of logged learning rates do not match the expected number"
|
2020-05-28 02:44:46 +00:00
|
|
|
|
|
|
|
|
2020-09-03 18:17:15 +00:00
|
|
|
def test_lr_monitor_param_groups(tmpdir):
    """Test that learning rates are extracted and logged per parameter group of a single optimizer.

    One Adam optimizer with two param groups, driven by one StepLR scheduler, must yield two
    lr series named ``lr-Adam/pg1`` and ``lr-Adam/pg2``, while ``lr_sch_names`` lists only ``lr-Adam``.
    """
    tutils.reset_seed()

    class CustomClassificationModel(ClassificationModel):
        def configure_optimizers(self):
            # Materialize the parameter list once and slice it into the two groups.
            params = list(self.parameters())
            param_groups = [
                # The first two parameter tensors get a 10x smaller lr than the rest.
                {"params": params[:2], "lr": self.lr * 0.1},
                {"params": params[2:], "lr": self.lr},
            ]
            optimizer = optim.Adam(param_groups)
            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
            return [optimizer], [lr_scheduler]

    model = CustomClassificationModel()
    dm = ClassifDataModule()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=2, limit_val_batches=0.1, limit_train_batches=0.5, callbacks=[lr_monitor]
    )
    trainer.fit(model, datamodule=dm)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    assert lr_monitor.lrs, "No learning rates logged"
    assert len(lr_monitor.lrs) == 2 * len(
        trainer.lr_schedulers
    ), "Number of learning rates logged does not match number of param groups"
    assert lr_monitor.lr_sch_names == ["lr-Adam"]
    assert list(lr_monitor.lrs.keys()) == ["lr-Adam/pg1", "lr-Adam/pg2"], "Names of learning rates not set correctly"
|
2020-12-14 07:38:10 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_lr_monitor_custom_name(tmpdir):
    """Check that a ``name`` entry in a scheduler dict replaces the default logging key."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            optimizers, schedulers = super().configure_optimizers()
            (scheduler,) = schedulers  # the parent model is expected to return exactly one scheduler
            return optimizers, [{"scheduler": scheduler, "name": "my_logging_name"}]

    monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=0.1,
        limit_train_batches=0.5,
        callbacks=[monitor],
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(TestModel())

    logged_names = list(monitor.lrs)
    assert monitor.lr_sch_names == logged_names == ["my_logging_name"]
|
2021-06-07 10:17:11 +00:00
|
|
|
|
|
|
|
|
2021-06-17 01:13:54 +00:00
|
|
|
def test_lr_monitor_custom_pg_name(tmpdir):
    """Check that a param group's ``name`` is used as the suffix of its logged lr key (``lr-SGD/linear``)."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            named_group = {"params": list(self.layer.parameters()), "name": "linear"}
            optimizer = torch.optim.SGD([named_group], lr=0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [scheduler]

    monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=2,
        limit_train_batches=2,
        callbacks=[monitor],
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(TestModel())

    assert monitor.lr_sch_names == ["lr-SGD"]
    assert list(monitor.lrs) == ["lr-SGD/linear"]
|
2021-06-17 01:13:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_lr_monitor_duplicate_custom_pg_names(tmpdir):
    """Check that one optimizer with two param groups sharing the same ``name`` is rejected at fit time."""
    tutils.reset_seed()

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.linear_a = torch.nn.Linear(32, 16)
            self.linear_b = torch.nn.Linear(16, 2)

        def forward(self, x):
            return self.linear_b(self.linear_a(x))

        def configure_optimizers(self):
            # Both groups deliberately reuse the name "linear" to trigger the error.
            param_groups = [
                {"params": list(layer.parameters()), "name": "linear"} for layer in (self.linear_a, self.linear_b)
            ]
            optimizer = torch.optim.SGD(param_groups, lr=0.1)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [scheduler]

    monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_val_batches=2,
        limit_train_batches=2,
        callbacks=[monitor],
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    expected_error = "A single `Optimizer` cannot have multiple parameter groups with identical"
    with pytest.raises(MisconfigurationException, match=expected_error):
        trainer.fit(TestModel())
|
|
|
|
|
|
|
|
|
2021-06-07 10:17:11 +00:00
|
|
|
def test_multiple_optimizers_basefinetuning(tmpdir):
    """Check lr logging while ``BackboneFinetuning`` adds param groups to several optimizers over epochs.

    Three Adam optimizers are configured, but only the first two have schedulers, so the monitor
    tracks ``lr-Adam`` and ``lr-Adam-1``. As frozen modules are unfrozen into new param groups,
    the logged keys switch to per-group ``/pgN`` names. A ``Check`` callback verifies the keys at
    each epoch start, and the full lr histories are asserted after training.
    """

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Sequential(
                torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.ReLU(True)
            )
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def forward(self, x):
            return self.layer(self.backbone(x))

        def configure_optimizers(self):
            # Skip parameters that were frozen (requires_grad=False).
            trainable = [p for p in self.parameters() if p.requires_grad]
            optimizers = [optim.Adam(trainable, lr=0.1) for _ in range(3)]
            # The third optimizer intentionally gets no scheduler.
            schedulers = [optim.lr_scheduler.StepLR(o, step_size=1, gamma=0.5) for o in optimizers[:2]]
            return optimizers, schedulers

    class Check(Callback):
        def on_train_epoch_start(self, trainer, pl_module) -> None:
            epoch = trainer.current_epoch
            total_groups = sum(len(o.param_groups) for o in trainer.optimizers)
            assert lr_monitor.lr_sch_names == ["lr-Adam", "lr-Adam-1"]

            if epoch == 0:
                assert total_groups == 3
                return

            if epoch == 1:
                assert total_groups == 4
                expected_keys = ["lr-Adam-1", "lr-Adam/pg1", "lr-Adam/pg2"]
            elif epoch == 2:
                assert total_groups == 5
                expected_keys = ["lr-Adam/pg1", "lr-Adam/pg2", "lr-Adam-1/pg1", "lr-Adam-1/pg2"]
            else:
                expected_keys = ["lr-Adam/pg1", "lr-Adam/pg2", "lr-Adam-1/pg1", "lr-Adam-1/pg2", "lr-Adam-1/pg3"]
            assert list(lr_monitor.lrs) == expected_keys

    class TestFinetuning(BackboneFinetuning):
        def freeze_before_training(self, pl_module):
            # Only backbone[2] (and the parameter-free ReLU) stay trainable initially.
            for module in (pl_module.backbone[0], pl_module.backbone[1], pl_module.layer):
                self.freeze(module)

        def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
            """Called when the epoch begins."""
            when = (epoch, opt_idx)
            if when == (1, 0):
                self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1)
            elif when == (2, 1):
                self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1)
            elif when == (3, 1):
                assert len(optimizer.param_groups) == 2
                self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1)
                assert len(optimizer.param_groups) == 3

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=5,
        limit_val_batches=0,
        limit_train_batches=2,
        callbacks=[TestFinetuning(), lr_monitor, Check()],
        enable_progress_bar=False,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    model = TestModel()
    model.training_epoch_end = None
    trainer.fit(model)

    # Each StepLR halves the lr per epoch; groups added later have shorter histories.
    expected_histories = {
        "lr-Adam/pg1": [0.1, 0.05, 0.025, 0.0125, 0.00625],
        "lr-Adam/pg2": [0.1, 0.05, 0.025, 0.0125],
        "lr-Adam-1/pg1": [0.1, 0.05, 0.025, 0.0125, 0.00625],
        "lr-Adam-1/pg2": [0.1, 0.05, 0.025],
        "lr-Adam-1/pg3": [0.1, 0.05],
    }
    for key, expected in expected_histories.items():
        assert lr_monitor.lrs[key] == expected
|