diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index c77d33b1e8..168f03d998 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 -
 
+- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))
+
 
 ## [1.8.0] - 2022-11-01
 
diff --git a/src/pytorch_lightning/callbacks/finetuning.py b/src/pytorch_lightning/callbacks/finetuning.py
index 11cd81f7a2..0722f7b8e0 100644
--- a/src/pytorch_lightning/callbacks/finetuning.py
+++ b/src/pytorch_lightning/callbacks/finetuning.py
@@ -164,10 +164,25 @@ class BaseFinetuning(Callback):
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
+            if isinstance(module, _BatchNorm):
+                module.track_running_stats = True
             # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
             for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
+    @staticmethod
+    def freeze_module(module: Module) -> None:
+        """Freezes the parameters of the provided module.
+
+        Args:
+            module: A given module
+        """
+        if isinstance(module, _BatchNorm):
+            module.track_running_stats = False
+        # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+        for param in module.parameters(recurse=False):
+            param.requires_grad = False
+
     @staticmethod
     def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
         """Freezes the parameters of the provided modules.
@@ -184,9 +199,7 @@ class BaseFinetuning(Callback):
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
-                for param in mod.parameters(recurse=False):
-                    param.requires_grad = False
+                BaseFinetuning.freeze_module(mod)
 
     @staticmethod
     def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
index cd9b1df221..7043c913ad 100644
--- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py
+++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py
@@ -147,19 +147,23 @@ def test_freeze_unfreeze_function(tmpdir):
             self.backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
 
     model = FreezeModel()
+    assert model.backbone[1].track_running_stats
     BaseFinetuning.freeze(model, train_bn=True)
     assert not model.backbone[0].weight.requires_grad
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
 
     BaseFinetuning.freeze(model, train_bn=False)
     assert not model.backbone[0].weight.requires_grad
     assert not model.backbone[1].weight.requires_grad
+    assert not model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
 
     BaseFinetuning.make_trainable(model)
     assert model.backbone[0].weight.requires_grad
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert model.backbone[3].weight.requires_grad
 
     BaseFinetuning.freeze(model.backbone[0], train_bn=False)
@@ -167,6 +171,7 @@ def test_freeze_unfreeze_function(tmpdir):
 
     BaseFinetuning.freeze(([(model.backbone[1]), [model.backbone[3]]]), train_bn=True)
     assert model.backbone[1].weight.requires_grad
+    assert model.backbone[1].track_running_stats
     assert not model.backbone[3].weight.requires_grad
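
For reference, a minimal usage sketch of the behaviour this patch introduces (not part of the diff). It assumes only `torch` and a `pytorch_lightning` install that includes this change; the `backbone` model is hypothetical and mirrors the one built in the updated test:

```python
import torch.nn as nn
from pytorch_lightning.callbacks import BaseFinetuning

# Hypothetical backbone, mirroring the one built in test_freeze_unfreeze_function.
backbone = nn.Sequential(nn.Linear(32, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))

# Freezing with train_bn=False now also stops the BatchNorm layer from
# updating its running statistics, in addition to disabling gradients.
BaseFinetuning.freeze(backbone, train_bn=False)
assert not backbone[1].weight.requires_grad
assert not backbone[1].track_running_stats

# make_trainable restores both requires_grad and track_running_stats.
BaseFinetuning.make_trainable(backbone)
assert backbone[1].weight.requires_grad
assert backbone[1].track_running_stats
```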