Support kwargs input for `LayerSummary` (#17709)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c3ad7568e1
commit
6f4524a25c
|
@ -88,6 +88,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
|
||||
|
||||
|
||||
- Support kwargs input for LayerSummary ([#17709](https://github.com/Lightning-AI/lightning/pull/17709))
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
|
||||
|
|
|
@ -25,6 +25,7 @@ from torch import Tensor
|
|||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
|
||||
from lightning.pytorch.utilities.rank_zero import WarningCache
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -89,14 +90,26 @@ class LayerSummary:
|
|||
def hook(_: nn.Module, inp: Any, out: Any) -> None:
|
||||
if len(inp) == 1:
|
||||
inp = inp[0]
|
||||
|
||||
self._in_size = parse_batch_shape(inp)
|
||||
self._out_size = parse_batch_shape(out)
|
||||
assert self._hook_handle is not None
|
||||
self._hook_handle.remove()
|
||||
|
||||
def hook_with_kwargs(_: nn.Module, args: Any, kwargs: Any, out: Any) -> None:
|
||||
# We can't write them in the same function, since the forward hook
|
||||
# uses positional arguments.
|
||||
|
||||
inp = (*args, *kwargs.values()) if kwargs is not None else args
|
||||
hook(_, inp, out)
|
||||
|
||||
handle = None
|
||||
if not isinstance(self._module, torch.jit.ScriptModule):
|
||||
handle = self._module.register_forward_hook(hook)
|
||||
if _TORCH_GREATER_EQUAL_2_0:
|
||||
handle = self._module.register_forward_hook(hook_with_kwargs, with_kwargs=True)
|
||||
else:
|
||||
handle = self._module.register_forward_hook(hook)
|
||||
|
||||
return handle
|
||||
|
||||
def detach_hook(self) -> None:
|
||||
|
|
|
@ -18,6 +18,7 @@ import pytest
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
|
||||
from lightning.pytorch import LightningModule, Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.utilities.model_summary.model_summary import (
|
||||
|
@ -270,7 +271,8 @@ def test_summary_with_scripted_modules(max_depth):
|
|||
([], UNKNOWN_SIZE),
|
||||
((1, 2, 3), [UNKNOWN_SIZE] * 3),
|
||||
(torch.tensor(0), UNKNOWN_SIZE),
|
||||
({"tensor": torch.zeros(1, 2, 3)}, UNKNOWN_SIZE),
|
||||
({"tensor": torch.zeros(1, 2, 3)}, [1, 2, 3]),
|
||||
({"tensor0": torch.zeros(1, 2, 3), "tensor1": torch.zeros(4, 5, 6)}, [[1, 2, 3], [4, 5, 6]]),
|
||||
(torch.zeros(2, 3, 4), [2, 3, 4]),
|
||||
([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]),
|
||||
((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]),
|
||||
|
@ -292,6 +294,10 @@ def test_example_input_array_types(example_input, expected_size, max_depth):
|
|||
def forward(self, *args, **kwargs):
|
||||
return self.layer(*args, **kwargs)
|
||||
|
||||
if isinstance(example_input, dict) and not _TORCH_GREATER_EQUAL_2_0:
|
||||
# kwargs are not supported when torch < 2.0
|
||||
expected_size = UNKNOWN_SIZE
|
||||
|
||||
model = DummyLightningModule()
|
||||
model.example_input_array = example_input
|
||||
summary = summarize(model, max_depth=max_depth)
|
||||
|
|
Loading…
Reference in New Issue