Fix summary hook handles not getting removed (#2298)

* detach hooks after completion

* detach hook

* update docs

* add test

* docs

* changelog
This commit is contained in:
Adrian Wälchli 2020-06-20 13:38:47 +02:00 committed by GitHub
parent c7f8367650
commit f972ab3a82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 6 deletions

View File

@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
## [0.8.1] - 2020-06-19

View File

@ -7,6 +7,7 @@ from typing import Tuple, Dict, Union, List, Any
import numpy as np
import torch
import torch.nn as nn
from torch.utils.hooks import RemovableHandle
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -54,11 +55,17 @@ class LayerSummary(object):
self._in_size = None
self._out_size = None
def _register_hook(self):
def __del__(self):
self.detach_hook()
def _register_hook(self) -> RemovableHandle:
"""
Registers a hook on the module that computes the input- and output size(s)
on the first forward pass. The hook will remove itself from the module, meaning that
Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
If the hook is called, it will remove itself from the from the module, meaning that
recursive models will only record their input- and output shapes once.
Return:
A handle for the installed hook.
"""
def hook(module, inp, out):
@ -66,16 +73,24 @@ class LayerSummary(object):
inp = inp[0]
self._in_size = parse_batch_shape(inp)
self._out_size = parse_batch_shape(out)
self._hook_handle.remove() # hook detaches itself from module
self._hook_handle.remove()
return self._module.register_forward_hook(hook)
def detach_hook(self):
"""
Removes the forward hook if it was not already removed in the forward pass.
Will be called after the summary is created.
"""
if self._hook_handle is not None:
self._hook_handle.remove()
@property
def in_size(self):
def in_size(self) -> Union[str, List]:
return self._in_size or UNKNOWN_SIZE
@property
def out_size(self):
def out_size(self) -> Union[str, List]:
return self._out_size or UNKNOWN_SIZE
@property
@ -180,6 +195,8 @@ class ModelSummary(object):
summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
if self._model.example_input_array is not None:
self._forward_example_input()
for layer in summary.values():
layer.detach_hook()
return summary
def _forward_example_input(self) -> None:

View File

@ -89,6 +89,20 @@ def test_linear_model_summary_shapes(device, dtype, mode):
assert model.device == device
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),
])
def test_hooks_removed_after_summarize(mode):
""" Test that all hooks were properly removed after summary, even ones that were not run. """
model = UnorderedModel()
summary = ModelSummary(model, mode=mode)
# hooks should be removed
for _, layer in summary.summarize().items():
handle = layer._hook_handle
assert handle.id not in handle.hooks_dict_ref()
@pytest.mark.parametrize(['mode'], [
pytest.param(ModelSummary.MODE_FULL),
pytest.param(ModelSummary.MODE_TOP),