Handle torch.jit scripted modules in layer summary (#6511)
This commit is contained in:
parent
0544efd453
commit
02fa32b7bc
|
@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
|
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
|
||||||
|
|
||||||
|
|
||||||
|
- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))
|
||||||
|
|
||||||
|
|
||||||
## [1.2.3] - 2021-03-09
|
## [1.2.3] - 2021-03-09
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
|
@ -16,7 +16,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -71,14 +71,15 @@ class LayerSummary(object):
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.detach_hook()
|
self.detach_hook()
|
||||||
|
|
||||||
def _register_hook(self) -> RemovableHandle:
|
def _register_hook(self) -> Optional[RemovableHandle]:
|
||||||
"""
|
"""
|
||||||
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
|
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
|
||||||
If the hook is called, it will remove itself from the from the module, meaning that
|
If the hook is called, it will remove itself from the from the module, meaning that
|
||||||
recursive models will only record their input- and output shapes once.
|
recursive models will only record their input- and output shapes once.
|
||||||
|
Registering hooks on :class:`~torch.jit.ScriptModule` is not supported.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A handle for the installed hook.
|
A handle for the installed hook, or ``None`` if registering the hook is not possible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def hook(module, inp, out):
|
def hook(module, inp, out):
|
||||||
|
@ -88,7 +89,10 @@ class LayerSummary(object):
|
||||||
self._out_size = parse_batch_shape(out)
|
self._out_size = parse_batch_shape(out)
|
||||||
self._hook_handle.remove()
|
self._hook_handle.remove()
|
||||||
|
|
||||||
return self._module.register_forward_hook(hook)
|
handle = None
|
||||||
|
if not isinstance(self._module, torch.jit.ScriptModule):
|
||||||
|
handle = self._module.register_forward_hook(hook)
|
||||||
|
return handle
|
||||||
|
|
||||||
def detach_hook(self):
|
def detach_hook(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -88,6 +88,19 @@ class MixedDtypeModel(LightningModule):
|
||||||
return self.reduce(self.embed(x))
|
return self.reduce(self.embed(x))
|
||||||
|
|
||||||
|
|
||||||
|
class PartialScriptModel(LightningModule):
|
||||||
|
""" A model which contains scripted layers. """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.layer1 = torch.jit.script(nn.Linear(5, 3))
|
||||||
|
self.layer2 = nn.Linear(3, 2)
|
||||||
|
self.example_input_array = torch.rand(2, 5)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layer2(self.layer1(x))
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_weights_summmary():
|
def test_invalid_weights_summmary():
|
||||||
""" Test that invalid value for weights_summary raises an error. """
|
""" Test that invalid value for weights_summary raises an error. """
|
||||||
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
|
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
|
||||||
|
@ -214,6 +227,15 @@ def test_summary_layer_types(mode):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
|
||||||
|
def test_summary_with_scripted_modules(mode):
|
||||||
|
model = PartialScriptModel()
|
||||||
|
summary = model.summarize(mode=mode)
|
||||||
|
assert summary.layer_types == ["RecursiveScriptModule", "Linear"]
|
||||||
|
assert summary.in_sizes == [UNKNOWN_SIZE, [2, 3]]
|
||||||
|
assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
|
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
|
||||||
@pytest.mark.parametrize(['example_input', 'expected_size'], [
|
@pytest.mark.parametrize(['example_input', 'expected_size'], [
|
||||||
pytest.param([], UNKNOWN_SIZE),
|
pytest.param([], UNKNOWN_SIZE),
|
||||||
|
@ -265,7 +287,7 @@ def test_empty_model_size(mode):
|
||||||
|
|
||||||
|
|
||||||
@RunIf(min_gpus=1, amp_native=True)
|
@RunIf(min_gpus=1, amp_native=True)
|
||||||
def test_model_size_precision(monkeypatch, tmpdir):
|
def test_model_size_precision(tmpdir):
|
||||||
""" Test model size for half and full precision. """
|
""" Test model size for half and full precision. """
|
||||||
model = PreCalculatedModel()
|
model = PreCalculatedModel()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue