# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
|
|
Finetunning Callback
|
|
^^^^^^^^^^^^^^^^^^^^
|
|
Freeze and unfreeze models for finetunning purposes
|
|
"""
|
|
from typing import Callable, Generator, Optional

import torch
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def multiplicative(epoch):
    """Default ``lambda_func``: multiplies the backbone learning rate by 2 every epoch."""
    return 2


class BaseFinetuningCallback(Callback):
    r"""
    BaseFinetuningCallback.

    Override any of the functions below with your own logic.
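
    Example (a minimal sketch; ``feature_extractor`` is an assumed attribute name on the
    user's LightningModule, not something this callback requires)::

        class FeatureExtractorFreezeUnfreeze(BaseFinetuningCallback):

            def freeze_before_training(self, pl_module):
                # freeze the assumed ``feature_extractor`` before training starts
                self.freeze(pl_module.feature_extractor, train_bn=True)

            def finetuning_function(self, pl_module, epoch, optimizer, opt_idx):
                # at epoch 10, unfreeze it and add its parameters to the optimizer
                if epoch == 10:
                    self.unfreeze_and_add_param_group(
                        pl_module.feature_extractor, optimizer,
                    )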
"""
|
|
|
|
BN_TYPES = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)
|
|
|
|
    @staticmethod
    def _make_trainable(module: Module) -> None:
        """Unfreezes a given module.

        Args:
            module: The module to unfreeze
        """
        for param in module.parameters():
            param.requires_grad = True
        module.train()

    @staticmethod
    def _recursive_freeze(module: Module, train_bn: bool = True) -> None:
        """Freezes the layers of a given module.

        Args:
            module: The module to freeze
            train_bn: If True, leave the BatchNorm layers in training mode
        """
        children = list(module.children())
        if not children:
            if not (isinstance(module, BaseFinetuningCallback.BN_TYPES) and train_bn):
                for param in module.parameters():
                    param.requires_grad = False
                module.eval()
            else:
                # Make the BN layers trainable
                BaseFinetuningCallback._make_trainable(module)
        else:
            for child in children:
                BaseFinetuningCallback._recursive_freeze(module=child, train_bn=train_bn)

    @staticmethod
    def filter_params(module: Module, train_bn: bool = True) -> Generator:
        """Yields the trainable parameters of a given module.

        Args:
            module: A given module
            train_bn: If True, leave the BatchNorm layers in training mode

        Returns:
            Generator
        """
        children = list(module.children())
        if not children:
            if not (isinstance(module, BaseFinetuningCallback.BN_TYPES) and train_bn):
                for param in module.parameters():
                    if param.requires_grad:
                        yield param
        else:
            for child in children:
                for param in BaseFinetuningCallback.filter_params(module=child, train_bn=train_bn):
                    yield param

    @staticmethod
    def freeze(module: Module, train_bn: bool = True) -> None:
        """Freezes the parameters of a given module, optionally keeping BatchNorm layers trainable.

        Args:
            module: The module to freeze (at least partially)
            train_bn: If True, leave the BatchNorm layers in training mode
        """
        BaseFinetuningCallback._recursive_freeze(module=module, train_bn=train_bn)

    @staticmethod
    def unfreeze_and_add_param_group(
        module: Module,
        optimizer: Optimizer,
        lr: Optional[float] = None,
        train_bn: bool = True,
        initial_denom_lr: float = 10.,
    ):
        """Unfreezes a module and adds its trainable parameters to an optimizer as a new param group.

        Args:
            module: The module to unfreeze
            optimizer: The optimizer that should update the unfrozen parameters
            lr: Learning rate for the new param group. Defaults to the optimizer's current
                learning rate divided by ``initial_denom_lr``.
            train_bn: If True, leave the BatchNorm layers in training mode
            initial_denom_lr: Divisor applied to the current learning rate when ``lr`` is not given
        """
        BaseFinetuningCallback._make_trainable(module)
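        # Derive the new group's learning rate: use ``lr`` when given, otherwise take the
        # first existing param group's lr and scale it down by ``initial_denom_lr``.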
        params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr)
        denom_lr = initial_denom_lr if lr is None else 1.
        optimizer.add_param_group(
            {
                'params': BaseFinetuningCallback.filter_params(module=module, train_bn=train_bn),
                'lr': params_lr / denom_lr,
            }
        )

    def on_before_accelerator_backend_setup(self, _, pl_module):
        self.freeze_before_training(pl_module)

    def on_train_epoch_start(self, trainer, pl_module):
        """Called when the epoch begins."""
        for opt_idx, optimizer in trainer.train_loop.prepare_optimizers():
            self.finetuning_function(pl_module, trainer.current_epoch, optimizer, opt_idx)

    def finetuning_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        """Override with the logic to unfreeze parts of ``pl_module`` and update ``optimizer`` as training progresses."""
        raise NotImplementedError

    def freeze_before_training(self, pl_module: LightningModule):
        """Override with the logic to freeze parts of ``pl_module`` before training starts."""
        raise NotImplementedError


class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
    r"""
    Finetune a backbone model based on a user-defined learning rate schedule.
    When the backbone learning rate reaches the current model learning rate
    and ``should_align`` is set to True, it will align with it for the rest of the training.

    Args:
        unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
        lambda_func: Scheduling function for increasing the backbone learning rate.
        backbone_initial_ratio_lr:
            Used to scale down the backbone learning rate compared to the rest of the model.
        backbone_initial_lr: Optional, initial learning rate for the backbone.
            By default, we will use ``current_learning_rate * backbone_initial_ratio_lr``.
        should_align: Whether to align with the current learning rate when the backbone
            learning rate reaches it.
        initial_denom_lr: When unfreezing the backbone, the initial learning rate will be
            ``current_learning_rate / initial_denom_lr``.
        train_bn: Whether to make Batch Normalization trainable.
        verbose: Display the current learning rate for the model and the backbone.
            Default: ``False``.
        round: Precision for displaying the learning rates.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback
        >>> multiplicative = lambda epoch: 1.5
        >>> backbone_finetuning = BackboneLambdaFinetuningCallback(200, multiplicative)
        >>> trainer = Trainer(callbacks=[backbone_finetuning])
    """

    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Callable = multiplicative,
        backbone_initial_ratio_lr: float = 10e-2,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.,
        train_bn: bool = True,
        verbose: bool = False,
        round: int = 12,
    ):
        super().__init__()
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.backbone_initial_lr = backbone_initial_lr
        self.lambda_func = lambda_func
        self.backbone_initial_ratio_lr = backbone_initial_ratio_lr
        self.should_align = should_align
        self.initial_denom_lr = initial_denom_lr
        self.train_bn = train_bn
        self.round = round
        self.verbose = verbose
        # Tracks the backbone lr from the previous epoch once the backbone has been unfrozen
        self.previous_backbone_lr: Optional[float] = None

    def on_fit_start(self, trainer, pl_module):
        if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
            return
        raise MisconfigurationException(
            "The LightningModule should have a nn.Module `backbone` attribute"
        )

    def freeze_before_training(self, pl_module: LightningModule):
        self.freeze(pl_module.backbone, train_bn=self.train_bn)

    def finetuning_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        """Called when the epoch begins."""

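        # At the milestone epoch, unfreeze the backbone and register its parameters
        # with the optimizer as a new param group at a reduced learning rate.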
        if epoch == self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]['lr']
            initial_backbone_lr = self.backbone_initial_lr \
                if self.backbone_initial_lr is not None \
                else current_lr * self.backbone_initial_ratio_lr
            self.previous_backbone_lr = initial_backbone_lr
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
                initial_backbone_lr,
                train_bn=self.train_bn,
                initial_denom_lr=self.initial_denom_lr
            )
            if self.verbose:
                log.info(f"Current lr: {round(current_lr, self.round)}, "
                         f"Backbone lr: {round(initial_backbone_lr, self.round)}")

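        # After the milestone, grow the backbone lr by ``lambda_func`` every epoch and,
        # if ``should_align`` is True, cap it at the current model lr once it catches up.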
        elif epoch > self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]['lr']
            next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
            next_current_backbone_lr = current_lr \
                if (self.should_align and next_current_backbone_lr > current_lr) \
                else next_current_backbone_lr
            optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
            self.previous_backbone_lr = next_current_backbone_lr
            if self.verbose:
                log.info(f"Current lr: {round(current_lr, self.round)}, "
                         f"Backbone lr: {round(next_current_backbone_lr, self.round)}")