Enable `Fabric.call` to call hooks on the LightningModule (#17874)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
ba943e3d6e
commit
bf2af0e591
|
@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
|
||||
|
||||
|
||||
- Added support for calling hooks on a LightningModule via `Fabric.call` ([#17874](https://github.com/Lightning-AI/lightning/pull/17874))
|
||||
|
||||
|
||||
- Added the parameter `Fabric.load(..., strict=True|False)` to enable non-strict loading of partial checkpoint state ([#17645](https://github.com/Lightning-AI/lightning/pull/17645))
|
||||
|
||||
|
||||
|
|
|
@ -230,6 +230,8 @@ class Fabric:
|
|||
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
|
||||
original_module._fabric = self # type: ignore[assignment]
|
||||
original_module._fabric_optimizers = optimizers # type: ignore[assignment]
|
||||
if original_module not in self._callbacks:
|
||||
self._callbacks.append(original_module)
|
||||
|
||||
self.call("on_after_setup", fabric=self, module=module)
|
||||
|
||||
|
@ -271,6 +273,8 @@ class Fabric:
|
|||
|
||||
if hasattr(original_module, "_fabric"): # this is probably a LightningModule
|
||||
original_module._fabric = self # type: ignore[assignment]
|
||||
if original_module not in self._callbacks:
|
||||
self._callbacks.append(original_module)
|
||||
|
||||
self._models_setup += 1
|
||||
return module
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from copy import deepcopy
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -58,3 +59,27 @@ def test_fabric_boring_lightning_module_manual():
|
|||
model.training_step(batch, 0) # .backward() and optimizer.step() happen inside training_step()
|
||||
|
||||
assert all(not torch.equal(before, after) for before, after in zip(parameters_before, model.parameters()))
|
||||
|
||||
|
||||
def test_fabric_call_lightning_module_hooks():
|
||||
"""Test that `Fabric.call` can call hooks on the LightningModule."""
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
def on_train_start(self):
|
||||
pass
|
||||
|
||||
def on_my_custom_hook(self, arg, kwarg=None):
|
||||
pass
|
||||
|
||||
fabric = Fabric(accelerator="cpu", devices=1)
|
||||
module = Mock(wraps=HookedModel())
|
||||
|
||||
_ = fabric.setup(module)
|
||||
_ = fabric.setup(module) # shouldn't add module to callbacks a second time
|
||||
assert fabric._callbacks == [module]
|
||||
|
||||
fabric.call("on_train_start")
|
||||
module.on_train_start.assert_called_once_with()
|
||||
|
||||
fabric.call("on_my_custom_hook", 1, kwarg="test")
|
||||
module.on_my_custom_hook.assert_called_once_with(1, kwarg="test")
|
||||
|
|
Loading…
Reference in New Issue