Handle torch.jit scripted modules in layer summary (#6511)

This commit is contained in:
parent 0544efd453
commit 02fa32b7bc
@@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))

+- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))

 ## [1.2.3] - 2021-03-09

 ### Fixed

@@ -16,7 +16,7 @@ import os
 import shutil
 import subprocess
 from collections import OrderedDict
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -71,14 +71,15 @@ class LayerSummary(object):
     def __del__(self):
         self.detach_hook()

-    def _register_hook(self) -> RemovableHandle:
+    def _register_hook(self) -> Optional[RemovableHandle]:
         """
         Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
         If the hook is called, it will remove itself from the module, meaning that
         recursive models will only record their input- and output shapes once.
+        Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.

         Return:
-            A handle for the installed hook.
+            A handle for the installed hook, or ``None`` if registering the hook is not possible.
         """

         def hook(module, inp, out):

@@ -88,7 +89,10 @@ class LayerSummary(object):
             self._out_size = parse_batch_shape(out)
             self._hook_handle.remove()

-        return self._module.register_forward_hook(hook)
+        handle = None
+        if not isinstance(self._module, torch.jit.ScriptModule):
+            handle = self._module.register_forward_hook(hook)
+        return handle

     def detach_hook(self):
         """

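For context, a minimal standalone sketch (not part of the diff) of the guard the new `_register_hook` relies on: on the PyTorch versions this fix targets, calling `register_forward_hook` on a `torch.jit.ScriptModule` can raise instead of returning a handle, so the summary simply skips scripted submodules. The `shape_hook` helper below is illustrative only.

```python
import torch
from torch import nn


def shape_hook(module, inp, out):
    # Illustrative stand-in for the summary hook: report the output shape once.
    print(type(module).__name__, list(out.shape))


scripted = torch.jit.script(nn.Linear(5, 3))  # a ScriptModule
plain = nn.Linear(3, 2)

handles = []
for module in (scripted, plain):
    # Same guard as the fix: only register hooks on regular nn.Modules.
    if not isinstance(module, torch.jit.ScriptModule):
        handles.append(module.register_forward_hook(shape_hook))

plain(scripted(torch.rand(2, 5)))  # prints "Linear [2, 2]"; the scripted layer is skipped
```
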
@@ -88,6 +88,19 @@ class MixedDtypeModel(LightningModule):
         return self.reduce(self.embed(x))


+class PartialScriptModel(LightningModule):
+    """ A model which contains scripted layers. """
+
+    def __init__(self):
+        super().__init__()
+        self.layer1 = torch.jit.script(nn.Linear(5, 3))
+        self.layer2 = nn.Linear(3, 2)
+        self.example_input_array = torch.rand(2, 5)
+
+    def forward(self, x):
+        return self.layer2(self.layer1(x))
+
+
 def test_invalid_weights_summmary():
     """ Test that invalid value for weights_summary raises an error. """
     with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):

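As a side note (not part of the diff), the test below expects the layer type `RecursiveScriptModule` because that is the wrapper class `torch.jit.script` produces for an `nn.Module`; a quick check:

```python
import torch
from torch import nn

scripted = torch.jit.script(nn.Linear(5, 3))
print(type(scripted).__name__)                       # RecursiveScriptModule
print(isinstance(scripted, torch.jit.ScriptModule))  # True, so the summary skips its hook
```
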
@@ -214,6 +227,15 @@ def test_summary_layer_types(mode):
     ]


+@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+def test_summary_with_scripted_modules(mode):
+    model = PartialScriptModel()
+    summary = model.summarize(mode=mode)
+    assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
+    assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
+    assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
+
+
 @pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
 @pytest.mark.parametrize(['example_input', 'expected_size'], [
     pytest.param([], UNKNOWN_SIZE),

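The behaviour pinned down by the new test can also be seen interactively; a minimal sketch assuming PyTorch Lightning ~1.2 and the `PartialScriptModel` defined above (output comments are indicative only, and `UNKNOWN_SIZE` renders as `"?"`):

```python
model = PartialScriptModel()
summary = model.summarize(mode="top")  # same as ModelSummary.MODE_TOP

# The scripted submodule reports unknown shapes; the plain Linear reports real ones.
print(summary.layer_types)  # ['RecursiveScriptModule', 'Linear']
print(summary.in_sizes)     # ['?', [2, 3]]
print(summary.out_sizes)    # ['?', [2, 2]]
```
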
@@ -265,7 +287,7 @@ def test_empty_model_size(mode):


 @RunIf(min_gpus=1, amp_native=True)
-def test_model_size_precision(monkeypatch, tmpdir):
+def test_model_size_precision(tmpdir):
     """ Test model size for half and full precision. """
     model = PreCalculatedModel()