Implement freeze batchnorm with freezing track running stats (#15063)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
parent f8675ff8be
commit 94bed87a34
@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
## [1.8.0] - 2022-11-01
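For context on the fix (this note and the sketch below are editorial, not part of the diff): disabling `requires_grad` on a BatchNorm layer freezes its affine weight and bias, but in train mode the layer keeps updating its running mean and variance as long as `track_running_stats` is `True`. A minimal standalone sketch with plain torch shows that drift and the effect of the extra flag the callback now toggles:

import torch
import torch.nn as nn

# Standalone sketch: freezing parameters alone does not freeze BatchNorm statistics.
bn = nn.BatchNorm1d(4)
for param in bn.parameters():
    param.requires_grad = False  # weight and bias frozen

before = bn.running_mean.clone()
bn.train()
bn(torch.randn(8, 4))  # one forward pass in train mode
print(torch.equal(before, bn.running_mean))  # typically False: the running stats still moved

bn.track_running_stats = False  # the extra step freeze_module now performs
before = bn.running_mean.clone()
bn(torch.randn(8, 4))
print(torch.equal(before, bn.running_mean))  # True: the running stats stay frozen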
@@ -164,10 +164,25 @@ class BaseFinetuning(Callback):
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
+            if isinstance(module, _BatchNorm):
+                module.track_running_stats = True
             # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
             for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
+    @staticmethod
+    def freeze_module(module: Module) -> None:
+        """Freezes the parameters of the provided module.
+
+        Args:
+            module: A given module
+        """
+        if isinstance(module, _BatchNorm):
+            module.track_running_stats = False
+        # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+        for param in module.parameters(recurse=False):
+            param.requires_grad = False
+
     @staticmethod
     def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
         """Freezes the parameters of the provided modules.
@@ -184,9 +199,7 @@ class BaseFinetuning(Callback):
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
-                for param in mod.parameters(recurse=False):
-                    param.requires_grad = False
+                BaseFinetuning.freeze_module(mod)
 
     @staticmethod
     def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
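The new `freeze_module` and the updated `make_trainable` can also be exercised directly on any module; a minimal sketch (the `BatchNorm1d` layer is purely illustrative, not taken from the PR):

import torch.nn as nn
from pytorch_lightning.callbacks import BaseFinetuning

bn = nn.BatchNorm1d(16)  # illustrative layer

BaseFinetuning.freeze_module(bn)
assert not bn.weight.requires_grad  # affine parameters frozen
assert not bn.track_running_stats   # statistics frozen as well

BaseFinetuning.make_trainable(bn)
assert bn.weight.requires_grad      # parameters trainable again
assert bn.track_running_stats       # statistics tracking restored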
@@ -147,19 +147,23 @@ def test_freeze_unfreeze_function(tmpdir):
             self.backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
 
     model = FreezeModel()
+    assert model.backbone[1].track_running_stats
     BaseFinetuning.freeze(model, train_bn=True)
     assert not model.backbone[0].weight.requires_grad
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
 
     BaseFinetuning.freeze(model, train_bn=False)
     assert not model.backbone[0].weight.requires_grad
     assert not model.backbone[1].weight.requires_grad
+    assert not model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
 
     BaseFinetuning.make_trainable(model)
     assert model.backbone[0].weight.requires_grad
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert model.backbone[3].weight.requires_grad
 
     BaseFinetuning.freeze(model.backbone[0], train_bn=False)
@@ -167,6 +171,7 @@ def test_freeze_unfreeze_function(tmpdir):
 
     BaseFinetuning.freeze(([(model.backbone[1]), [model.backbone[3]]]), train_bn=True)
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
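In practice these helpers are consumed through a `BaseFinetuning` subclass rather than called by hand. A minimal sketch of that pattern follows; the class name, `pl_module.backbone`, and the unfreeze epoch are illustrative assumptions, and the hook signatures are as in Lightning 1.8 (later releases drop the optimizer-index argument):

from pytorch_lightning.callbacks import BaseFinetuning


class BackboneFreezer(BaseFinetuning):
    """Sketch of a finetuning callback that freezes a backbone, BatchNorm statistics included."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # train_bn=False now also sets track_running_stats=False on BatchNorm layers
        self.freeze(pl_module.backbone, train_bn=False)

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        if current_epoch == self.unfreeze_at_epoch:
            # make_trainable (called internally) restores track_running_stats=True
            self.unfreeze_and_add_param_group(modules=pl_module.backbone, optimizer=optimizer, train_bn=True)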