Support kwargs input for `LayerSummary` (#17709)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Authored by Leng Yue on 2023-05-30 00:17:50 -04:00, committed by GitHub
parent c3ad7568e1
commit 6f4524a25c
3 changed files with 24 additions and 2 deletions


@@ -88,6 +88,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
+- Support kwargs input for LayerSummary ([#17709](https://github.com/Lightning-AI/lightning/pull/17709))
 ### Deprecated
 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


@@ -25,6 +25,7 @@ from torch import Tensor
 from torch.utils.hooks import RemovableHandle
 import lightning.pytorch as pl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.rank_zero import WarningCache
 log = logging.getLogger(__name__)
@@ -89,14 +90,26 @@ class LayerSummary:
         def hook(_: nn.Module, inp: Any, out: Any) -> None:
             if len(inp) == 1:
                 inp = inp[0]
             self._in_size = parse_batch_shape(inp)
             self._out_size = parse_batch_shape(out)
             assert self._hook_handle is not None
             self._hook_handle.remove()

+        def hook_with_kwargs(_: nn.Module, args: Any, kwargs: Any, out: Any) -> None:
+            # We can't write them in the same function, since the forward hook
+            # uses positional arguments.
+            inp = (*args, *kwargs.values()) if kwargs is not None else args
+            hook(_, inp, out)
+
         handle = None
         if not isinstance(self._module, torch.jit.ScriptModule):
-            handle = self._module.register_forward_hook(hook)
+            if _TORCH_GREATER_EQUAL_2_0:
+                handle = self._module.register_forward_hook(hook_with_kwargs, with_kwargs=True)
+            else:
+                handle = self._module.register_forward_hook(hook)
         return handle

     def detach_hook(self) -> None:
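
For context (not part of the diff): on PyTorch >= 2.0, `register_forward_hook(..., with_kwargs=True)` delivers positional and keyword arguments to the hook separately, which is what `hook_with_kwargs` flattens back into a single tuple before handing it to `parse_batch_shape`. A minimal standalone sketch of that behavior; the `Scale` module and `demo_hook` names are chosen here purely for illustration:

```python
import torch
from torch import nn


class Scale(nn.Module):
    """Toy module whose forward accepts a keyword argument."""

    def forward(self, x: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
        return x * factor


def demo_hook(module: nn.Module, args: tuple, kwargs: dict, output: torch.Tensor) -> None:
    # With `with_kwargs=True` the hook sees args and kwargs separately,
    # so keyword-only inputs are no longer invisible to shape inspection.
    all_inputs = (*args, *kwargs.values())
    print([tuple(t.shape) if isinstance(t, torch.Tensor) else t for t in all_inputs])


module = Scale()
handle = module.register_forward_hook(demo_hook, with_kwargs=True)
module(torch.zeros(1, 2, 3), factor=2.0)  # prints [(1, 2, 3), 2.0]
handle.remove()
```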


@@ -18,6 +18,7 @@ import pytest
 import torch
 import torch.nn as nn
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.model_summary.model_summary import (
@@ -270,7 +271,8 @@ def test_summary_with_scripted_modules(max_depth):
         ([], UNKNOWN_SIZE),
         ((1, 2, 3), [UNKNOWN_SIZE] * 3),
         (torch.tensor(0), UNKNOWN_SIZE),
-        ({"tensor": torch.zeros(1, 2, 3)}, UNKNOWN_SIZE),
+        ({"tensor": torch.zeros(1, 2, 3)}, [1, 2, 3]),
+        ({"tensor0": torch.zeros(1, 2, 3), "tensor1": torch.zeros(4, 5, 6)}, [[1, 2, 3], [4, 5, 6]]),
         (torch.zeros(2, 3, 4), [2, 3, 4]),
         ([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]),
         ((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]),
@@ -292,6 +294,10 @@ def test_example_input_array_types(example_input, expected_size, max_depth):
         def forward(self, *args, **kwargs):
             return self.layer(*args, **kwargs)

+    if isinstance(example_input, dict) and not _TORCH_GREATER_EQUAL_2_0:
+        # kwargs are not supported when torch < 2.0
+        expected_size = UNKNOWN_SIZE
+
     model = DummyLightningModule()
     model.example_input_array = example_input
     summary = summarize(model, max_depth=max_depth)
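
End to end, the change means a `LightningModule` whose `example_input_array` is a dict of keyword arguments now gets real input sizes in its summary on PyTorch >= 2.0, while older versions still report unknown sizes (as the test above encodes). A minimal usage sketch mirroring that test, summarized without a `Trainer`; `KwargsModel` is an illustrative name, not part of the commit:

```python
import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.model_summary import summarize


class KwargsModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 4)

    def forward(self, *args, **kwargs):
        return self.layer(*args, **kwargs)


model = KwargsModel()
# Keyword-only example input: with torch >= 2.0 the summary reports [1, 2, 3]
# as the layer's "In sizes" instead of "?".
model.example_input_array = {"input": torch.zeros(1, 2, 3)}
print(summarize(model, max_depth=1))
```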