Fix summary hook handles not getting removed (#2298)
* detach hooks after completion
* detach hook
* update docs
* add test
* docs
* changelog
parent c7f8367650
commit f972ab3a82
CHANGELOG.md:

```diff
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
+
 ## [0.8.1] - 2020-06-19
```
pytorch_lightning/core/memory.py:

```diff
@@ -7,6 +7,7 @@ from typing import Tuple, Dict, Union, List, Any
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
```
```diff
@@ -54,11 +55,17 @@ class LayerSummary(object):
         self._in_size = None
         self._out_size = None
 
+    def __del__(self):
+        self.detach_hook()
+
-    def _register_hook(self):
+    def _register_hook(self) -> RemovableHandle:
         """
-        Registers a hook on the module that computes the input- and output size(s)
-        on the first forward pass. The hook will remove itself from the module, meaning that
+        Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
+        If the hook is called, it will remove itself from the module, meaning that
         recursive models will only record their input- and output shapes once.
+
+        Return:
+            A handle for the installed hook.
         """
 
         def hook(module, inp, out):
```
```diff
@@ -66,16 +73,24 @@ class LayerSummary(object):
             inp = inp[0]
             self._in_size = parse_batch_shape(inp)
             self._out_size = parse_batch_shape(out)
-            self._hook_handle.remove()  # hook detaches itself from module
+            self._hook_handle.remove()
 
         return self._module.register_forward_hook(hook)
 
+    def detach_hook(self):
+        """
+        Removes the forward hook if it was not already removed in the forward pass.
+        Will be called after the summary is created.
+        """
+        if self._hook_handle is not None:
+            self._hook_handle.remove()
+
     @property
-    def in_size(self):
+    def in_size(self) -> Union[str, List]:
         return self._in_size or UNKNOWN_SIZE
 
     @property
-    def out_size(self):
+    def out_size(self) -> Union[str, List]:
         return self._out_size or UNKNOWN_SIZE
 
     @property
```
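Taken together, these two hunks give each `LayerSummary` three chances to release its handle: the hook removes itself on the first forward call, `detach_hook()` is invoked explicitly after the summary (see `summarize()` below), and `__del__` acts as a last resort when the summary object is garbage-collected. A minimal sketch of the same self-removing-hook pattern in plain PyTorch (the names are illustrative, not Lightning code):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 2)

def record_shapes(mod, inp, out):
    # fires once: record the shapes, then detach so repeated or
    # recursive forward passes do not re-trigger the hook
    print("in:", list(inp[0].shape), "out:", list(out.shape))
    handle.remove()

handle = module.register_forward_hook(record_shapes)

module(torch.randn(1, 4))  # hook fires and detaches itself
module(torch.randn(1, 4))  # silent: the hook is already gone

# If no forward pass ever runs, the hook never fires and the handle stays
# attached -- exactly the leak this commit fixes. An explicit remove() is
# the safety net; it is a no-op when the hook is already detached.
handle.remove()
```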
```diff
@@ -180,6 +195,8 @@ class ModelSummary(object):
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
             self._forward_example_input()
+        for layer in summary.values():
+            layer.detach_hook()
         return summary
 
     def _forward_example_input(self) -> None:
```
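`summarize()` now sweeps every `LayerSummary` whether or not its hook fired, which matters for submodules that `example_input_array` never exercises, and for the case where no example input is set and no forward pass happens at all. A hedged way to double-check that guarantee on any model, relying on PyTorch's private `_forward_hooks` dict (an implementation detail, not public API):

```python
import torch.nn as nn

def assert_no_forward_hooks(model: nn.Module) -> None:
    """Fail if any submodule still has a forward hook attached."""
    for name, module in model.named_modules():
        assert not module._forward_hooks, f"leftover forward hook on {name!r}"
```

Called after `ModelSummary(model, mode=...).summarize()`, this should pass for both `MODE_FULL` and `MODE_TOP` once the fix is in.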
In the model summary tests:

```diff
@@ -89,6 +89,20 @@ def test_linear_model_summary_shapes(device, dtype, mode):
     assert model.device == device
 
 
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_hooks_removed_after_summarize(mode):
+    """ Test that all hooks were properly removed after summary, even ones that were not run. """
+    model = UnorderedModel()
+    summary = ModelSummary(model, mode=mode)
+    # hooks should be removed
+    for _, layer in summary.summarize().items():
+        handle = layer._hook_handle
+        assert handle.id not in handle.hooks_dict_ref()
+
+
 @pytest.mark.parametrize(['mode'], [
     pytest.param(ModelSummary.MODE_FULL),
     pytest.param(ModelSummary.MODE_TOP),
```
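The assertion in the new test leans on two `RemovableHandle` attributes: `id`, the handle's key in the module's hook dict, and `hooks_dict_ref()`, a weak reference to that dict. The same check in isolation, with plain PyTorch and no Lightning:

```python
import torch.nn as nn

layer = nn.Linear(3, 3)
handle = layer.register_forward_hook(lambda mod, inp, out: None)

assert handle.id in handle.hooks_dict_ref()      # still attached
handle.remove()
assert handle.id not in handle.hooks_dict_ref()  # detached, as the test expects
```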