rename callback FineTune arg `round` (#9711)
* rename CB Tune arg round

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f94faa9cd3 · commit b3e9dff32d
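For users of ``BackboneFinetuning``, this is a keyword rename at construction time; a minimal sketch of the before/after (the epoch and precision values are made up):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BackboneFinetuning

    # Before this commit the display-precision argument was called `round`:
    # finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=10, verbose=True, round=4)

    # After this commit it is called `rounding`:
    finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=10, verbose=True, rounding=4)
    trainer = Trainer(callbacks=[finetuning])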
@@ -53,32 +53,29 @@ class BaseFinetuning(Callback):

    Example::

        class MyModel(LightningModule)

            ...

            def configure_optimizer(self):
                # Make sure to filter the parameters based on `requires_grad`
                return Adam(filter(lambda p: p.requires_grad, self.parameters))

        class FeatureExtractorFreezeUnfreeze(BaseFinetuning):

            def __init__(self, unfreeze_at_epoch=10):
                self._unfreeze_at_epoch = unfreeze_at_epoch

            def freeze_before_training(self, pl_module):
                # freeze any module you want
                # Here, we are freezing `feature_extractor`
                self.freeze(pl_module.feature_extractor)

            def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
                # When `current_epoch` is 10, feature_extractor will start training.
                if current_epoch == self._unfreeze_at_epoch:
                    self.unfreeze_and_add_param_group(
                        modules=pl_module.feature_extractor,
                        optimizer=optimizer,
                        train_bn=True,
                    )

        >>> from torch.optim import Adam
        >>> class MyModel(pl.LightningModule):
        ...     def configure_optimizer(self):
        ...         # Make sure to filter the parameters based on `requires_grad`
        ...         return Adam(filter(lambda p: p.requires_grad, self.parameters))
        ...
        >>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        ...     def __init__(self, unfreeze_at_epoch=10):
        ...         self._unfreeze_at_epoch = unfreeze_at_epoch
        ...
        ...     def freeze_before_training(self, pl_module):
        ...         # freeze any module you want
        ...         # Here, we are freezing `feature_extractor`
        ...         self.freeze(pl_module.feature_extractor)
        ...
        ...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        ...         # When `current_epoch` is 10, feature_extractor will start training.
        ...         if current_epoch == self._unfreeze_at_epoch:
        ...             self.unfreeze_and_add_param_group(
        ...                 modules=pl_module.feature_extractor,
        ...                 optimizer=optimizer,
        ...                 train_bn=True,
        ...             )
    """

    def __init__(self):
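A subclass like the one in the docstring example above is attached to the ``Trainer`` like any other callback; a minimal sketch, assuming ``MyModel`` and ``FeatureExtractorFreezeUnfreeze`` are defined as in that example and the model exposes a ``feature_extractor`` module:

    import pytorch_lightning as pl

    model = MyModel()  # defined as in the docstring example above
    callback = FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)
    trainer = pl.Trainer(max_epochs=20, callbacks=[callback])
    trainer.fit(model)  # `feature_extractor` stays frozen until epoch 10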
@@ -208,7 +205,7 @@ class BaseFinetuning(Callback):

        if removed_params:
            rank_zero_warn(
                "The provided params to be freezed already exist within another group of this optimizer."
                "The provided params to be frozen already exist within another group of this optimizer."
                " Those parameters will be skipped.\n"
                "HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
                f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
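The hint in that warning points at the usual pattern: build the optimizer only over parameters that still require gradients, so frozen weights are not registered twice once the callback unfreezes them later. A hedged sketch of that setup (the module layout and hyperparameters are illustrative):

    import torch.nn as nn
    import pytorch_lightning as pl
    from torch.optim import Adam

    class FinetuneModel(pl.LightningModule):  # hypothetical module for illustration
        def __init__(self):
            super().__init__()
            self.feature_extractor = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
            self.classifier = nn.Linear(64, 2)

        def configure_optimizers(self):
            # Only parameters with requires_grad=True are handed to the optimizer.
            return Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)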
@@ -227,22 +224,14 @@ class BaseFinetuning(Callback):
        """Unfreezes a module and adds its parameters to an optimizer.

        Args:
            modules: A module or iterable of modules to unfreeze.
                Their parameters will be added to an optimizer as a new param group.
            optimizer: The provided optimizer will receive the new parameters and will add them via
                `add_param_group`
            lr: Learning rate for the new param group.
            initial_denom_lr: If no lr is provided, the learning rate from the first param group will be used
                and divided by initial_denom_lr.
                and divided by `initial_denom_lr`.
            train_bn: Whether to train the BatchNormalization layers.

        Returns:
            None
        """
        BaseFinetuning.make_trainable(modules)
        params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
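The method is a static helper and can be exercised outside a ``Trainer``; a small sketch of the lr bookkeeping described above (the toy modules and values are made up, mirroring the test further down):

    import torch.nn as nn
    from torch.optim import SGD
    from pytorch_lightning.callbacks import BaseFinetuning

    backbone = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
    head = nn.Linear(32, 2)

    BaseFinetuning.freeze(backbone)              # backbone parameters stop requiring grad
    optimizer = SGD(head.parameters(), lr=0.1)   # optimizer initially only sees the head

    # No explicit `lr`, so the new group uses 0.1 / initial_denom_lr (default 10.0) = 0.01.
    BaseFinetuning.unfreeze_and_add_param_group(backbone, optimizer=optimizer, train_bn=True)
    assert len(optimizer.param_groups) == 2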
@@ -252,7 +241,7 @@ class BaseFinetuning(Callback):
        if params:
            optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

    def on_before_accelerator_backend_setup(self, trainer, pl_module):
    def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        self.freeze_before_training(pl_module)

    @staticmethod
@@ -283,7 +272,7 @@ class BaseFinetuning(Callback):
                self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
            )

    def on_train_epoch_start(self, trainer, pl_module):
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the epoch begins."""
        # import is here to avoid circular imports
        from pytorch_lightning.loops.utilities import _get_active_optimizers
@@ -294,45 +283,37 @@ class BaseFinetuning(Callback):
            current_param_groups = optimizer.param_groups
            self._store(pl_module, opt_idx, num_param_groups, current_param_groups)

    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int):
    def finetune_function(
        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
    ) -> None:
        """Override to add your unfreeze logic."""
        raise NotImplementedError

    def freeze_before_training(self, pl_module: "pl.LightningModule"):
    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        """Override to add your freeze logic."""
        raise NotImplementedError


class BackboneFinetuning(BaseFinetuning):
    r"""
    r"""Finetune a backbone model based on a learning rate user-defined scheduling.

    Finetune a backbone model based on a learning rate user-defined scheduling.
    When the backbone learning rate reaches the current model learning rate
    and ``should_align`` is set to True, it will align with it for the rest of the training.

    Args:
        unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
        lambda_func: Scheduling function for increasing backbone learning rate.
        backbone_initial_ratio_lr:
            Used to scale down the backbone learning rate compared to rest of model
        backbone_initial_lr: Optional, initial learning rate for the backbone.
            By default, we will use current_learning / backbone_initial_ratio_lr
            By default, we will use ``current_learning / backbone_initial_ratio_lr``
        should_align: Whether to align with current learning rate when backbone learning rate
            reaches it.
        initial_denom_lr: When unfreezing the backbone, the initial learning rate will be
            current_learning_rate / initial_denom_lr.
            ``current_learning_rate / initial_denom_lr``.
        train_bn: Whether to make Batch Normalization trainable.
        verbose: Display current learning rate for model and backbone
        round: Precision for displaying learning rate
        rounding: Precision for displaying learning rate

    Example::
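The docstring's own example falls outside this hunk; as a hedged illustration of the arguments above (the multiplier, ratio, and epoch values are invented, and the ``lambda_func`` semantics are paraphrased from its description):

    from pytorch_lightning.callbacks import BackboneFinetuning

    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=10,    # backbone stays frozen for the first 10 epochs
        lambda_func=lambda epoch: 1.5,    # grow the backbone lr by 1.5x per epoch afterwards
        backbone_initial_ratio_lr=0.1,    # start the backbone at 10% of the current model lr
        should_align=True,                # stop growing once it catches up with the model lr
        verbose=True,                     # log both learning rates while finetuning
        rounding=4,                       # display precision for those logs (renamed from `round`)
    )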
@@ -354,8 +335,8 @@ class BackboneFinetuning(BaseFinetuning):
        initial_denom_lr: float = 10.0,
        train_bn: bool = True,
        verbose: bool = False,
        round: int = 12,
    ):
        rounding: int = 12,
    ) -> None:
        super().__init__()

        self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
@@ -366,12 +347,12 @@ class BackboneFinetuning(BaseFinetuning):
        self.initial_denom_lr: float = initial_denom_lr
        self.train_bn: bool = train_bn
        self.verbose: bool = verbose
        self.round: int = round
        self.rounding: int = rounding
        self.previous_backbone_lr: Optional[float] = None

    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> Dict[int, Any]:
    ) -> Dict[str, Any]:
        return {
            "internal_optimizer_metadata": self._internal_optimizer_metadata,
            "previous_backbone_lr": self.previous_backbone_lr,
@@ -383,7 +364,7 @@ class BackboneFinetuning(BaseFinetuning):
        self.previous_backbone_lr = callback_state["previous_backbone_lr"]
        super().on_load_checkpoint(trainer, pl_module, callback_state["internal_optimizer_metadata"])

    def on_fit_start(self, trainer, pl_module):
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """
        Raises:
            MisconfigurationException:
@@ -393,10 +374,12 @@ class BackboneFinetuning(BaseFinetuning):
            return super().on_fit_start(trainer, pl_module)
        raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")

    def freeze_before_training(self, pl_module: "pl.LightningModule"):
    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int):
    def finetune_function(
        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
    ) -> None:
        """Called when the epoch begins."""
        if epoch == self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]["lr"]
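As the ``MisconfigurationException`` above makes explicit, the LightningModule must expose its backbone under exactly that attribute name; a minimal sketch of a compatible module (layer sizes and optimizer settings are arbitrary):

    import torch
    import torch.nn as nn
    import pytorch_lightning as pl

    class BackboneModel(pl.LightningModule):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())  # must be named `backbone`
            self.head = nn.Linear(64, 2)

        def forward(self, x):
            return self.head(self.backbone(x))

        def configure_optimizers(self):
            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)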
@@ -415,8 +398,8 @@ class BackboneFinetuning(BaseFinetuning):
            )
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.round)}, "
                    f"Backbone lr: {round(initial_backbone_lr, self.round)}"
                    f"Current lr: {round(current_lr, self.rounding)}, "
                    f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
                )

        elif epoch > self.unfreeze_backbone_at_epoch:
@@ -431,6 +414,6 @@ class BackboneFinetuning(BaseFinetuning):
            self.previous_backbone_lr = next_current_backbone_lr
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.round)}, "
                    f"Backbone lr: {round(next_current_backbone_lr, self.round)}"
                    f"Current lr: {round(current_lr, self.rounding)}, "
                    f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
                )
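The two log statements above only change how the learning rates are displayed, not how they are applied; ``rounding`` is simply passed to Python's built-in ``round`` (which the old parameter name shadowed inside ``__init__``). A tiny sketch with made-up values:

    current_lr = 0.000123456789
    rounding = 6
    print(f"Current lr: {round(current_lr, rounding)}")  # -> Current lr: 0.000123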
@@ -186,7 +186,7 @@ def test_unfreeze_and_add_param_group_function(tmpdir):
    model = FreezeModel()
    optimizer = SGD(model.backbone[0].parameters(), lr=0.01)

    with pytest.warns(UserWarning, match="The provided params to be freezed already"):
    with pytest.warns(UserWarning, match="The provided params to be frozen already"):
        BaseFinetuning.unfreeze_and_add_param_group(model.backbone[0], optimizer=optimizer)
    assert optimizer.param_groups[0]["lr"] == 0.01
@@ -197,7 +197,7 @@ def test_unfreeze_and_add_param_group_function(tmpdir):
    assert torch.equal(optimizer.param_groups[1]["params"][0], model.backbone[1].weight)
    assert model.backbone[1].weight.requires_grad

    with pytest.warns(UserWarning, match="The provided params to be freezed already"):
    with pytest.warns(UserWarning, match="The provided params to be frozen already"):
        BaseFinetuning.unfreeze_and_add_param_group(model, optimizer=optimizer, lr=100, train_bn=False)
    assert len(optimizer.param_groups) == 3
    assert optimizer.param_groups[2]["lr"] == 100