Enable `Fabric.call` to call hooks on the LightningModule (#17874)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
M. Fox 2023-06-21 06:25:05 -07:00 committed by GitHub
parent ba943e3d6e
commit bf2af0e591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 0 deletions

View File

@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
- Added support for calling hooks on a LightningModule via `Fabric.call` ([#17874](https://github.com/Lightning-AI/lightning/pull/17874))
- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645))

View File

@ -230,6 +230,8 @@ class Fabric:
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self # type: ignore[assignment]
original_module._fabric_optimizers = optimizers # type: ignore[assignment]
if original_module not in self._callbacks:
self._callbacks.append(original_module)
self.call("on_after_setup", fabric=self, module=module)
@ -271,6 +273,8 @@ class Fabric:
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
original_module._fabric = self # type: ignore[assignment]
if original_module not in self._callbacks:
self._callbacks.append(original_module)
self._models_setup += 1
return module

View File

@ -13,6 +13,7 @@
# limitations under the License.
from copy import deepcopy
from unittest.mock import Mock
import torch
@ -58,3 +59,27 @@ def test_fabric_boring_lightning_module_manual():
model.training_step(batch, 0) # .backward() and optimizer.step() happen inside training_step()
assert all(not torch.equal(before, after) for before, after in zip(parameters_before, model.parameters()))
def test_fabric_call_lightning_module_hooks():
"""Test that `Fabric.call` can call hooks on the LightningModule."""
class HookedModel(BoringModel):
def on_train_start(self):
pass
def on_my_custom_hook(self, arg, kwarg=None):
pass
fabric = Fabric(accelerator="cpu", devices=1)
module = Mock(wraps=HookedModel())
_ = fabric.setup(module)
_ = fabric.setup(module) # shouldn't add module to callbacks a second time
assert fabric._callbacks == [module]
fabric.call("on_train_start")
module.on_train_start.assert_called_once_with()
fabric.call("on_my_custom_hook", 1, kwarg="test")
module.on_my_custom_hook.assert_called_once_with(1, kwarg="test")