lightning/pytorch_lightning/callbacks/finetuning.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Finetuning Callback
^^^^^^^^^^^^^^^^^^^^
Freeze and unfreeze models for finetuning purposes
"""
import logging
from typing import Callable, Generator, Iterable, List, Optional, Union

import torch
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
from torch.optim.optimizer import Optimizer

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


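# Default ``lambda_func`` used by ``BackboneFinetuning`` below: once the backbone has been
# unfrozen, its learning rate is multiplied by 2 at every subsequent epoch.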
def multiplicative(epoch):
    return 2


class BaseFinetuning(Callback):
    r"""
    This class implements the base logic for writing your own Finetuning Callback.

    Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic.

    ``freeze_before_training``: This method is called before ``configure_optimizers``
        and should be used to freeze any module's parameters.

    ``finetune_function``: This method is called on every train epoch start and should be used to
        ``unfreeze`` any parameters. Those parameters need to be added in a new ``param_group``
        within the optimizer.

    .. note:: Make sure to filter the parameters based on ``requires_grad``.

    Example::

        class MyModel(LightningModule):

            ...

            def configure_optimizers(self):
                # Make sure to filter the parameters based on `requires_grad`
                return Adam(filter(lambda p: p.requires_grad, self.parameters()))

        class FeatureExtractorFreezeUnfreeze(BaseFinetuning):

            def __init__(self, unfreeze_at_epoch=10):
                self._unfreeze_at_epoch = unfreeze_at_epoch

            def freeze_before_training(self, pl_module):
                # freeze any module you want
                # Here, we are freezing ``feature_extractor``
                self.freeze(pl_module.feature_extractor)

            def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
                # When `current_epoch` is 10, feature_extractor will start training.
                if current_epoch == self._unfreeze_at_epoch:
                    self.unfreeze_and_add_param_group(
                        modules=pl_module.feature_extractor,
                        optimizer=optimizer,
                        train_bn=True,
                    )
    """

    @staticmethod
    def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
        """
        This function is used to flatten a module or an iterable of modules into a list of its modules.

        Args:
            modules: A given module or an iterable of modules

        Returns:
            List of modules
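
        Example (a minimal sketch; the layers are arbitrary placeholders)::

            encoder = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
            BaseFinetuning.flatten_modules(encoder)
            # -> [Linear(in_features=4, out_features=4, bias=True), ReLU()]
            # Containers such as ``Sequential`` are dropped; only the leaf modules remain.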
"""
if isinstance(modules, Iterable):
_modules = []
for m in modules:
_modules.extend(BaseFinetuning.flatten_modules(m))
else:
_modules = modules.modules()
return list(
filter(
lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules
)
)

    @staticmethod
    def filter_params(
        modules: Union[Module, Iterable[Union[Module, Iterable]]],
        train_bn: bool = True,
        requires_grad: bool = True
    ) -> Generator:
        """Yields the `requires_grad` parameters of a given module or list of modules.

        Args:
            modules: A given module or an iterable of modules
            train_bn: Whether to train BatchNorm module
            requires_grad: Whether to create a generator for trainable or non-trainable parameters.

        Returns:
            Generator
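
        Example (a minimal sketch; ``model`` is any ``nn.Module`` you supply)::

            # collect only the parameters that are currently trainable
            trainable = list(BaseFinetuning.filter_params(model, requires_grad=True))
            # or only the frozen ones
            frozen = list(BaseFinetuning.filter_params(model, requires_grad=False))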
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and not train_bn:
continue
for param in mod.parameters():
if param.requires_grad == requires_grad:
yield param

    @staticmethod
    def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
        """
        Unfreezes the parameters of the provided modules

        Args:
            modules: A given module or an iterable of modules
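
        Example (a minimal sketch; ``model.feature_extractor`` is a placeholder attribute)::

            BaseFinetuning.make_trainable(model.feature_extractor)
            # every parameter of ``feature_extractor`` now has ``requires_grad=True``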
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
for param in module.parameters():
param.requires_grad = True

    @staticmethod
    def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
        """
        Freezes the parameters of the provided modules

        Args:
            modules: A given module or an iterable of modules
            train_bn: If True, leave the BatchNorm layers in training mode

        Returns:
            None
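
        Example (a minimal sketch; ``model.backbone`` is a placeholder attribute)::

            # freeze everything except the BatchNorm layers
            BaseFinetuning.freeze(model.backbone, train_bn=True)
            # freeze the BatchNorm layers as well
            BaseFinetuning.freeze(model.backbone, train_bn=False)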
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and train_bn:
BaseFinetuning.make_trainable(mod)
else:
for param in mod.parameters():
param.requires_grad = False

    @staticmethod
    def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
        """
        This function is used to exclude any parameter which already exists in
        this optimizer

        Args:
            optimizer: Optimizer used for parameter exclusion
            params: Iterable of parameters used to check against the provided optimizer

        Returns:
            List of parameters not contained in this optimizer param groups
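
        Example (a minimal sketch; ``model`` and ``optimizer`` are placeholders)::

            new_params = BaseFinetuning.filter_on_optimizer(optimizer, model.head.parameters())
            # ``new_params`` only contains parameters not already tracked by ``optimizer``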
"""
out_params = []
removed_params = []
for param in params:
if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
out_params.append(param)
else:
removed_params.append(param)
if removed_params:
rank_zero_warn(
"The provided params to be freezed already exist within another group of this optimizer."
" Those parameters will be skipped.\n"
"HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ", UserWarning
)
return out_params

    @staticmethod
    def unfreeze_and_add_param_group(
        modules: Union[Module, Iterable[Union[Module, Iterable]]],
        optimizer: Optimizer,
        lr: Optional[float] = None,
        initial_denom_lr: float = 10.,
        train_bn: bool = True,
    ) -> None:
        """
        Unfreezes a module and adds its parameters to an optimizer.

        Args:
            modules: A module or iterable of modules to unfreeze.
                Their parameters will be added to an optimizer as a new param group.
            optimizer: The provided optimizer will receive the new parameters through
                ``add_param_group``.
            lr: Learning rate for the new param group.
            initial_denom_lr: If no lr is provided, the learning rate from the first param group
                will be used and divided by ``initial_denom_lr``.
            train_bn: Whether to train the BatchNormalization layers.

        Returns:
            None
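
        Example (a minimal sketch; ``pl_module`` and ``optimizer`` are placeholders)::

            # unfreeze the backbone and train it at 1/10th of the current learning rate
            BaseFinetuning.unfreeze_and_add_param_group(
                pl_module.backbone, optimizer, initial_denom_lr=10., train_bn=True
            )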
"""
BaseFinetuning.make_trainable(modules)
params_lr = optimizer.param_groups[0]['lr'] if lr is None else float(lr)
denom_lr = initial_denom_lr if lr is None else 1.
params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
params = BaseFinetuning.filter_on_optimizer(optimizer, params)
if params:
optimizer.add_param_group({
'params': params,
'lr': params_lr / denom_lr,
})

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
        self.freeze_before_training(pl_module)

    def on_train_epoch_start(self, trainer, pl_module):
        """Called when the epoch begins."""
        for opt_idx, optimizer in trainer.train_loop.prepare_optimizers():
            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)

    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        """
        Override to add your unfreeze logic
        """
        raise NotImplementedError

    def freeze_before_training(self, pl_module: LightningModule):
        """
        Override to add your freeze logic
        """
        raise NotImplementedError


class BackboneFinetuning(BaseFinetuning):
    r"""
    Finetune a backbone model based on a user-defined learning rate schedule.
    When the backbone learning rate reaches the current model learning rate
    and ``should_align`` is set to True, it will align with it for the rest of the training.

    Args:
        unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
        lambda_func: Scheduling function for increasing backbone learning rate.
        backbone_initial_ratio_lr:
            Used to scale down the backbone learning rate compared to rest of model
        backbone_initial_lr: Optional, Initial learning rate for the backbone.
            By default, we will use ``current_learning_rate * backbone_initial_ratio_lr``
        should_align: Whether to align with current learning rate when backbone learning
            rate reaches it.
        initial_denom_lr: When unfreezing the backbone, the initial learning rate will be
            ``current_learning_rate / initial_denom_lr``.
        train_bn: Whether to make Batch Normalization trainable.
        verbose: Display current learning rate for model and backbone
        round: Precision for displaying learning rate

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import BackboneFinetuning
        >>> multiplicative = lambda epoch: 1.5
        >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
        >>> trainer = Trainer(callbacks=[backbone_finetuning])
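
    .. note::

        A worked example with the defaults: if the current model learning rate is ``1e-2``,
        the backbone param group is added at epoch ``unfreeze_backbone_at_epoch`` with an
        initial learning rate of ``1e-2 * backbone_initial_ratio_lr = 1e-3``. On every
        following epoch it is multiplied by ``lambda_func(epoch + 1)`` (``2`` with the
        default ``multiplicative``) and, if ``should_align`` is True, capped at the current
        model learning rate once it reaches it.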
"""

    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Callable = multiplicative,
        backbone_initial_ratio_lr: float = 10e-2,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.,
        train_bn: bool = True,
        verbose: bool = False,
        round: int = 12,
    ):
        self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
        self.backbone_initial_lr = backbone_initial_lr
        self.lambda_func = lambda_func
        self.backbone_initial_ratio_lr = backbone_initial_ratio_lr
        self.should_align = should_align
        self.initial_denom_lr = initial_denom_lr
        self.train_bn = train_bn
        self.round = round
        self.verbose = verbose

    def on_fit_start(self, trainer, pl_module):
        """
        Raises:
            MisconfigurationException:
                If LightningModule has no ``nn.Module`` `backbone` attribute.
        """
        if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
            return
        raise MisconfigurationException("The LightningModule should have an nn.Module `backbone` attribute")

    def freeze_before_training(self, pl_module: LightningModule):
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
        """Called when the epoch begins."""
        if epoch == self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]['lr']
            initial_backbone_lr = self.backbone_initial_lr if self.backbone_initial_lr is not None \
                else current_lr * self.backbone_initial_ratio_lr
            self.previous_backbone_lr = initial_backbone_lr
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
                initial_backbone_lr,
                train_bn=self.train_bn,
                initial_denom_lr=self.initial_denom_lr
            )
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.round)}, "
                    f"Backbone lr: {round(initial_backbone_lr, self.round)}"
                )

        elif epoch > self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]['lr']
            next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
            next_current_backbone_lr = current_lr if (self.should_align and next_current_backbone_lr > current_lr) \
                else next_current_backbone_lr
            optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
            self.previous_backbone_lr = next_current_backbone_lr
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.round)}, "
                    f"Backbone lr: {round(next_current_backbone_lr, self.round)}"
                )