Implement freeze batchnorm with freezing track running stats (#15063)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Sitcebelly 2022-11-01 22:11:42 +06:00 committed by GitHub
parent f8675ff8be
commit 94bed87a34
3 changed files with 23 additions and 3 deletions


@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))

## [1.8.0] - 2022-11-01


@@ -164,10 +164,25 @@ class BaseFinetuning(Callback):
        """
        modules = BaseFinetuning.flatten_modules(modules)
        for module in modules:
+            if isinstance(module, _BatchNorm):
+                module.track_running_stats = True
            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
            for param in module.parameters(recurse=False):
                param.requires_grad = True

+    @staticmethod
+    def freeze_module(module: Module) -> None:
+        """Freezes the parameters of the provided module.
+
+        Args:
+            module: A given module
+        """
+        if isinstance(module, _BatchNorm):
+            module.track_running_stats = False
+        # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+        for param in module.parameters(recurse=False):
+            param.requires_grad = False
+
    @staticmethod
    def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
        """Freezes the parameters of the provided modules.
@@ -184,9 +199,7 @@ class BaseFinetuning(Callback):
            if isinstance(mod, _BatchNorm) and train_bn:
                BaseFinetuning.make_trainable(mod)
            else:
-                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
-                for param in mod.parameters(recurse=False):
-                    param.requires_grad = False
+                BaseFinetuning.freeze_module(mod)

    @staticmethod
    def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
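Taken together, these two hunks make `BaseFinetuning.freeze(..., train_bn=False)` disable running-statistics tracking on BatchNorm layers, while `make_trainable` re-enables it. Below is a minimal sketch of the resulting behaviour; it is not part of this commit, assumes a Lightning release that includes this change, and the toy `backbone` module is purely illustrative:

```python
import torch
from torch import nn

from pytorch_lightning.callbacks import BaseFinetuning

# Toy backbone with a BatchNorm layer; names and shapes are illustrative only.
backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU())

BaseFinetuning.freeze(backbone, train_bn=False)
assert not backbone[1].weight.requires_grad   # parameters are frozen
assert not backbone[1].track_running_stats    # new: statistics are frozen too

# Even in train() mode, the frozen running statistics no longer drift.
backbone.train()
running_mean_before = backbone[1].running_mean.clone()
backbone(torch.randn(8, 32))
assert torch.equal(running_mean_before, backbone[1].running_mean)

BaseFinetuning.make_trainable(backbone)
assert backbone[1].weight.requires_grad
assert backbone[1].track_running_stats        # tracking restored by make_trainable
```

The middle block is the symptom this PR addresses: before the change, a frozen BatchNorm layer kept updating its running statistics whenever the model was in train mode.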


@@ -147,19 +147,23 @@ def test_freeze_unfreeze_function(tmpdir):
            self.backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))

    model = FreezeModel()
+    assert model.backbone[1].track_running_stats
    BaseFinetuning.freeze(model, train_bn=True)
    assert not model.backbone[0].weight.requires_grad
    assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
    assert not model.backbone[3].weight.requires_grad

    BaseFinetuning.freeze(model, train_bn=False)
    assert not model.backbone[0].weight.requires_grad
    assert not model.backbone[1].weight.requires_grad
+    assert not model.backbone[1].track_running_stats
    assert not model.backbone[3].weight.requires_grad

    BaseFinetuning.make_trainable(model)
    assert model.backbone[0].weight.requires_grad
    assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
    assert model.backbone[3].weight.requires_grad

    BaseFinetuning.freeze(model.backbone[0], train_bn=False)
@@ -167,6 +171,7 @@ def test_freeze_unfreeze_function(tmpdir):
    BaseFinetuning.freeze(([(model.backbone[1]), [model.backbone[3]]]), train_bn=True)
    assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
    assert not model.backbone[3].weight.requires_grad
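For completeness, the helpers tested above are usually driven through a `BaseFinetuning` subclass registered as a Trainer callback. The following is a hedged sketch under the 1.8-era hook signatures; the `backbone` attribute, the class name, and the epoch threshold are assumptions for illustration, not part of this commit:

```python
from pytorch_lightning.callbacks import BaseFinetuning


class BackboneFreezeUnfreeze(BaseFinetuning):
    """Freeze ``pl_module.backbone`` before training and unfreeze it at a chosen epoch."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # With this change, train_bn=False also sets track_running_stats=False on the
        # backbone's BatchNorm layers, so their statistics stop drifting while frozen.
        self.freeze(pl_module.backbone, train_bn=False)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        if current_epoch == self._unfreeze_at_epoch:
            # make_trainable (called by unfreeze_and_add_param_group) re-enables
            # track_running_stats alongside requires_grad.
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                train_bn=True,
            )
```

Because `unfreeze_and_add_param_group` delegates to `make_trainable`, the backbone's BatchNorm layers resume tracking statistics at the same epoch their parameters become trainable.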