rename callback FineTune arg `round` (#9711)

* rename CB Tune arg round

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Jirka Borovec 2021-10-06 10:39:36 +02:00 committed by GitHub
parent f94faa9cd3
commit b3e9dff32d
2 changed files with 50 additions and 67 deletions


@@ -53,32 +53,29 @@ class BaseFinetuning(Callback):
Example::
class MyModel(LightningModule)
...
def configure_optimizer(self):
# Make sure to filter the parameters based on `requires_grad`
return Adam(filter(lambda p: p.requires_grad, self.parameters))
class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
def __init__(self, unfreeze_at_epoch=10):
self._unfreeze_at_epoch = unfreeze_at_epoch
def freeze_before_training(self, pl_module):
# freeze any module you want
# Here, we are freezing `feature_extractor`
self.freeze(pl_module.feature_extractor)
def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
# When `current_epoch` is 10, feature_extractor will start training.
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
modules=pl_module.feature_extractor,
optimizer=optimizer,
train_bn=True,
)
>>> from torch.optim import Adam
>>> class MyModel(pl.LightningModule):
... def configure_optimizers(self):
... # Make sure to filter the parameters based on `requires_grad`
... return Adam(filter(lambda p: p.requires_grad, self.parameters()))
...
>>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
... def __init__(self, unfreeze_at_epoch=10):
... self._unfreeze_at_epoch = unfreeze_at_epoch
...
... def freeze_before_training(self, pl_module):
... # freeze any module you want
... # Here, we are freezing `feature_extractor`
... self.freeze(pl_module.feature_extractor)
...
... def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
... # When `current_epoch` is 10, feature_extractor will start training.
... if current_epoch == self._unfreeze_at_epoch:
... self.unfreeze_and_add_param_group(
... modules=pl_module.feature_extractor,
... optimizer=optimizer,
... train_bn=True,
... )
"""
def __init__(self):
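For orientation, a minimal sketch of wiring such a callback into a run; `MyModel` and `FeatureExtractorFreezeUnfreeze` refer to the docstring example above, while the data setup and `max_epochs` are illustrative assumptions (a real run would also need `MyModel` to define `forward` and `training_step`):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    # Toy data only so the sketch is self-contained.
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    train_loader = DataLoader(dataset, batch_size=8)

    model = MyModel()  # assumed to expose `feature_extractor` as in the docstring
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)],
    )
    trainer.fit(model, train_loader)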
@@ -208,7 +205,7 @@ class BaseFinetuning(Callback):
if removed_params:
rank_zero_warn(
"The provided params to be freezed already exist within another group of this optimizer."
"The provided params to be frozen already exist within another group of this optimizer."
" Those parameters will be skipped.\n"
"HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
@@ -227,22 +224,14 @@ class BaseFinetuning(Callback):
"""Unfreezes a module and adds its parameters to an optimizer.
Args:
modules: A module or iterable of modules to unfreeze.
Their parameters will be added to an optimizer as a new param group.
optimizer: The provided optimizer will receive the new parameters via `add_param_group`.
lr: Learning rate for the new param group.
initial_denom_lr: If no lr is provided, the learning rate from the first param group will be used
and divided by initial_denom_lr.
and divided by `initial_denom_lr`.
train_bn: Whether to train the BatchNormalization layers.
Returns:
None
"""
BaseFinetuning.make_trainable(modules)
params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
@@ -252,7 +241,7 @@ class BaseFinetuning(Callback):
if params:
optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})
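As the lines above show, when no explicit `lr` is given the new param group inherits the first group's learning rate divided by `initial_denom_lr` (10 by default). A standalone sketch of that behaviour outside of any callback; the module shapes are arbitrary:

    from torch import nn
    from torch.optim import SGD
    from pytorch_lightning.callbacks import BaseFinetuning

    head = nn.Linear(16, 2)
    backbone = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
    BaseFinetuning.freeze(backbone)

    optimizer = SGD(head.parameters(), lr=0.1)

    # No lr given: the new group gets 0.1 / 10.0 = 0.01.
    BaseFinetuning.unfreeze_and_add_param_group(backbone, optimizer=optimizer)
    print(optimizer.param_groups[1]["lr"])  # 0.01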
def on_before_accelerator_backend_setup(self, trainer, pl_module):
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self.freeze_before_training(pl_module)
@staticmethod
@@ -283,7 +272,7 @@ class BaseFinetuning(Callback):
self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
)
def on_train_epoch_start(self, trainer, pl_module):
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when the epoch begins."""
# import is here to avoid circular imports
from pytorch_lightning.loops.utilities import _get_active_optimizers
@@ -294,45 +283,37 @@ class BaseFinetuning(Callback):
current_param_groups = optimizer.param_groups
self._store(pl_module, opt_idx, num_param_groups, current_param_groups)
def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int):
def finetune_function(
self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
) -> None:
"""Override to add your unfreeze logic."""
raise NotImplementedError
def freeze_before_training(self, pl_module: "pl.LightningModule"):
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
"""Override to add your freeze logic."""
raise NotImplementedError
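Both hooks shown here are the extension points: `freeze_before_training` runs once before accelerator setup and `finetune_function` runs at the start of every training epoch for each active optimizer. A hedged sketch of a staged schedule built on them; the submodule names and epochs are made up for illustration:

    class StagedUnfreeze(BaseFinetuning):  # hypothetical subclass
        def __init__(self, schedule=None):
            super().__init__()
            # Maps epoch -> name of the submodule to unfreeze at that epoch.
            self.schedule = schedule or {5: "encoder", 10: "embeddings"}

        def freeze_before_training(self, pl_module):
            for name in self.schedule.values():
                self.freeze(getattr(pl_module, name))

        def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
            name = self.schedule.get(epoch)
            if name is not None:
                self.unfreeze_and_add_param_group(getattr(pl_module, name), optimizer=optimizer)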
class BackboneFinetuning(BaseFinetuning):
r"""
r"""Finetune a backbone model based on a learning rate user-defined scheduling.
Finetune a backbone model based on a learning rate user-defined scheduling.
When the backbone learning rate reaches the current model learning rate
and ``should_align`` is set to True, it will align with it for the rest of the training.
Args:
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
lambda_func: Scheduling function for increasing backbone learning rate.
backbone_initial_ratio_lr:
Used to scale down the backbone learning rate compared to the rest of the model.
backbone_initial_lr: Optional, initial learning rate for the backbone.
By default, we will use current_learning / backbone_initial_ratio_lr
By default, we will use ``current_learning / backbone_initial_ratio_lr``
should_align: Whether to align with the current learning rate when the backbone learning rate
reaches it.
initial_denom_lr: When unfreezing the backbone, the initial learning rate will be
current_learning_rate / initial_denom_lr.
``current_learning_rate / initial_denom_lr``.
train_bn: Whether to make Batch Normalization trainable.
verbose: Display the current learning rate for the model and backbone.
round: Precision for displaying learning rate
rounding: Precision for displaying learning rate
Example::
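The docstring's own ``Example::`` block falls outside this hunk; as an illustration of the renamed argument, a hedged sketch of constructing the callback after this change, with an arbitrary multiplier and epoch:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BackboneFinetuning

    backbone_finetuning = BackboneFinetuning(
        unfreeze_backbone_at_epoch=10,       # unfreeze the backbone at epoch 10
        lambda_func=lambda epoch: 1.5,       # grow the backbone lr by 50% per epoch
        verbose=True,                        # log both learning rates
        rounding=4,                          # formerly `round`: digits shown in the log
    )
    trainer = Trainer(callbacks=[backbone_finetuning])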
@@ -354,8 +335,8 @@ class BackboneFinetuning(BaseFinetuning):
initial_denom_lr: float = 10.0,
train_bn: bool = True,
verbose: bool = False,
round: int = 12,
):
rounding: int = 12,
) -> None:
super().__init__()
self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
@@ -366,12 +347,12 @@ class BackboneFinetuning(BaseFinetuning):
self.initial_denom_lr: float = initial_denom_lr
self.train_bn: bool = train_bn
self.verbose: bool = verbose
self.round: int = round
self.rounding: int = rounding
self.previous_backbone_lr: Optional[float] = None
def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> Dict[int, Any]:
) -> Dict[str, Any]:
return {
"internal_optimizer_metadata": self._internal_optimizer_metadata,
"previous_backbone_lr": self.previous_backbone_lr,
@@ -383,7 +364,7 @@ class BackboneFinetuning(BaseFinetuning):
self.previous_backbone_lr = callback_state["previous_backbone_lr"]
super().on_load_checkpoint(trainer, pl_module, callback_state["internal_optimizer_metadata"])
def on_fit_start(self, trainer, pl_module):
def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""
Raises:
MisconfigurationException:
@@ -393,10 +374,12 @@ class BackboneFinetuning(BaseFinetuning):
return super().on_fit_start(trainer, pl_module)
raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")
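`on_fit_start` only checks for an `nn.Module` attribute literally named `backbone`. A minimal sketch of a LightningModule that satisfies that contract; layer sizes, loss, and optimizer are illustrative assumptions:

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class BackboneClassifier(pl.LightningModule):  # hypothetical module
        def __init__(self):
            super().__init__()
            # Must be named `backbone`, otherwise on_fit_start raises MisconfigurationException.
            self.backbone = nn.Sequential(nn.Linear(32, 16), nn.ReLU())
            self.head = nn.Linear(16, 2)

        def forward(self, x):
            return self.head(self.backbone(x))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self(x), y)

        def configure_optimizers(self):
            # Keep the (initially frozen) backbone out of the first param group.
            return torch.optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3)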
def freeze_before_training(self, pl_module: "pl.LightningModule"):
def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
self.freeze(pl_module.backbone)
def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int):
def finetune_function(
self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
) -> None:
"""Called when the epoch begins."""
if epoch == self.unfreeze_backbone_at_epoch:
current_lr = optimizer.param_groups[0]["lr"]
@@ -415,8 +398,8 @@ class BackboneFinetuning(BaseFinetuning):
)
if self.verbose:
log.info(
f"Current lr: {round(current_lr, self.round)}, "
f"Backbone lr: {round(initial_backbone_lr, self.round)}"
f"Current lr: {round(current_lr, self.rounding)}, "
f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
)
elif epoch > self.unfreeze_backbone_at_epoch:
@@ -431,6 +414,6 @@ class BackboneFinetuning(BaseFinetuning):
self.previous_backbone_lr = next_current_backbone_lr
if self.verbose:
log.info(
f"Current lr: {round(current_lr, self.round)}, "
f"Backbone lr: {round(next_current_backbone_lr, self.round)}"
f"Current lr: {round(current_lr, self.rounding)}, "
f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
)
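To make the logging above concrete: after unfreezing, the backbone learning rate starts near `current_lr * backbone_initial_ratio_lr`, is scaled by `lambda_func` each epoch, and with `should_align=True` stops once it catches up with the model learning rate. The exact update is in lines this hunk elides, so the following is only an approximate simulation of the documented behaviour:

    # Illustrative only: simulate the documented schedule for a model lr of 1e-2.
    current_lr = 1e-2
    backbone_initial_ratio_lr = 0.10
    lambda_func = lambda epoch: 1.5
    should_align = True

    backbone_lr = current_lr * backbone_initial_ratio_lr  # lr right after unfreezing
    for epoch in range(10, 16):
        backbone_lr = lambda_func(epoch + 1) * backbone_lr
        if should_align:
            backbone_lr = min(backbone_lr, current_lr)  # align once it reaches the model lr
        print(f"Current lr: {round(current_lr, 12)}, Backbone lr: {round(backbone_lr, 12)}")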


@@ -186,7 +186,7 @@ def test_unfreeze_and_add_param_group_function(tmpdir):
model = FreezeModel()
optimizer = SGD(model.backbone[0].parameters(), lr=0.01)
with pytest.warns(UserWarning, match="The provided params to be freezed already"):
with pytest.warns(UserWarning, match="The provided params to be frozen already"):
BaseFinetuning.unfreeze_and_add_param_group(model.backbone[0], optimizer=optimizer)
assert optimizer.param_groups[0]["lr"] == 0.01
@@ -197,7 +197,7 @@ def test_unfreeze_and_add_param_group_function(tmpdir):
assert torch.equal(optimizer.param_groups[1]["params"][0], model.backbone[1].weight)
assert model.backbone[1].weight.requires_grad
with pytest.warns(UserWarning, match="The provided params to be freezed already"):
with pytest.warns(UserWarning, match="The provided params to be frozen already"):
BaseFinetuning.unfreeze_and_add_param_group(model, optimizer=optimizer, lr=100, train_bn=False)
assert len(optimizer.param_groups) == 3
assert optimizer.param_groups[2]["lr"] == 100
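The `FreezeModel` helper used here is defined elsewhere in the test file; a self-contained sketch of the behaviour these assertions exercise, using plain `torch` modules and the `warnings` module instead of the fixture and `pytest.warns`:

    import warnings
    from torch import nn
    from torch.optim import SGD
    from pytorch_lightning.callbacks import BaseFinetuning

    backbone = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))
    optimizer = SGD(backbone[0].parameters(), lr=0.01)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # backbone[0]'s parameters are already tracked by the optimizer, so the call
        # warns that the params to be frozen already exist and skips them.
        BaseFinetuning.unfreeze_and_add_param_group(backbone[0], optimizer=optimizer)

    assert any("already exist" in str(w.message) for w in caught)
    assert len(optimizer.param_groups) == 1  # no duplicate group was added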