From bf2af0e591e7823824b3343b7c404a60efdd6cd4 Mon Sep 17 00:00:00 2001 From: "M. Fox" <120434191+lightningforever@users.noreply.github.com> Date: Wed, 21 Jun 2023 06:25:05 -0700 Subject: [PATCH] Enable `Fabric.call` to call hooks on the LightningModule (#17874) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/CHANGELOG.md | 3 +++ src/lightning/fabric/fabric.py | 4 +++ .../models/test_fabric_integration.py | 25 +++++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 58ef785788..dcc1a2583c 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623)) +- Added support for calling hooks on a LightningModule via `Fabric.call` ([#17874](https://github.com/Lightning-AI/lightning/pull/17874)) + + - Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645)) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index ba4ef13e31..c6dcd102d3 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -230,6 +230,8 @@ class Fabric: if hasattr(original_module, "_fabric"): # this is probably a LightningModule original_module._fabric = self # type: ignore[assignment] original_module._fabric_optimizers = optimizers # type: ignore[assignment] + if original_module not in self._callbacks: + self._callbacks.append(original_module) self.call("on_after_setup", fabric=self, module=module) @@ -271,6 +273,8 @@ class Fabric: if hasattr(original_module, "_fabric"): # this is probably a LightningModule original_module._fabric = self # type: ignore[assignment] + if original_module not in self._callbacks: + self._callbacks.append(original_module) self._models_setup += 1 return module diff --git a/tests/tests_pytorch/models/test_fabric_integration.py b/tests/tests_pytorch/models/test_fabric_integration.py index 957f5c8f95..8c6203b752 100644 --- a/tests/tests_pytorch/models/test_fabric_integration.py +++ b/tests/tests_pytorch/models/test_fabric_integration.py @@ -13,6 +13,7 @@ # limitations under the License. from copy import deepcopy +from unittest.mock import Mock import torch @@ -58,3 +59,27 @@ def test_fabric_boring_lightning_module_manual(): model.training_step(batch, 0) # .backward() and optimizer.step() happen inside training_step() assert all(not torch.equal(before, after) for before, after in zip(parameters_before, model.parameters())) + + +def test_fabric_call_lightning_module_hooks(): + """Test that `Fabric.call` can call hooks on the LightningModule.""" + + class HookedModel(BoringModel): + def on_train_start(self): + pass + + def on_my_custom_hook(self, arg, kwarg=None): + pass + + fabric = Fabric(accelerator="cpu", devices=1) + module = Mock(wraps=HookedModel()) + + _ = fabric.setup(module) + _ = fabric.setup(module) # shouldn't add module to callbacks a second time + assert fabric._callbacks == [module] + + fabric.call("on_train_start") + module.on_train_start.assert_called_once_with() + + fabric.call("on_my_custom_hook", 1, kwarg="test") + module.on_my_custom_hook.assert_called_once_with(1, kwarg="test")