Add `ModelSummary.max_depth` (#8062)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Palermo 2021-07-01 11:08:16 +01:00 committed by GitHub
parent 3c74502919
commit 36b893c43e
6 changed files with 140 additions and 36 deletions

View File

@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))
- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
### Changed
@@ -270,6 +273,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025))
- Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
### Removed
- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
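For reference, a minimal sketch of the interface these changelog entries describe (`LitModel` is a placeholder `LightningModule` subclass, not part of this commit):

    from pytorch_lightning.core.memory import ModelSummary

    model = LitModel()
    # max_depth generalizes the old binary mode switch: any non-negative depth
    # (or -1 for unlimited) is now valid, e.g. two levels of nesting:
    print(ModelSummary(model, max_depth=2))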

View File

@@ -1642,15 +1642,24 @@ class LightningModule(
return splits
def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
model_summary = None
if mode in ModelSummary.MODES:
model_summary = ModelSummary(self, mode=mode)
log.info("\n" + str(model_summary))
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
# temporary mapping from mode to max_depth
if max_depth is None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
rank_zero_deprecation(
f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
)
model_summary = ModelSummary(self, max_depth=max_depth)
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
else:
model_summary = ModelSummary(self, max_depth=max_depth)
log.info("\n" + str(model_summary))
return model_summary
def freeze(self) -> None:
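To make the temporary mode-to-max_depth mapping above concrete, a hedged sketch of equivalent calls (`LitModel` again stands in for any `LightningModule`):

    model = LitModel()
    model.summarize(max_depth=1)   # replacement for the deprecated mode="top"
    model.summarize(max_depth=-1)  # replacement for the deprecated mode="full"
    model.summarize(mode="top")    # still works in v1.4, but emits a deprecation warning
    model.summarize(mode="typo")   # raises MisconfigurationException, as before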

View File

@@ -131,11 +131,17 @@ class ModelSummary(object):
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
Args:
model: The model to summarize (also referred to as the root module)
model: The model to summarize (also referred to as the root module).
mode: Can be one of
- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
.. deprecated:: v1.4
This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.
max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no
summary. Defaults to 1.
The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
@@ -160,7 +166,7 @@ class ModelSummary(object):
... return self.net(x)
...
>>> model = LitModel()
>>> ModelSummary(model, mode='top') # doctest: +NORMALIZE_WHITESPACE
>>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
@@ -169,7 +175,7 @@ class ModelSummary(object):
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
>>> ModelSummary(model, mode='full') # doctest: +NORMALIZE_WHITESPACE
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
@@ -182,14 +188,28 @@ class ModelSummary(object):
0.530 Total estimated model params size (MB)
"""
MODE_TOP = "top"
MODE_FULL = "full"
MODE_DEFAULT = MODE_TOP
MODES = [MODE_FULL, MODE_TOP]
MODES = dict(top=1, full=-1) # TODO: remove in v1.6
def __init__(self, model, mode: str = MODE_DEFAULT):
def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
self._model = model
self._mode = mode
# temporary mapping from mode to max_depth
if max_depth is None or mode is not None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
from pytorch_lightning.utilities import rank_zero_deprecation
rank_zero_deprecation(
f"Argument `mode` in `ModelSummary` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
)
else:
from pytorch_lightning.utilities.exceptions import MisconfigurationException
raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
self._max_depth = max_depth
self._layer_summary = self.summarize()
# 1 byte -> 8 bits
# TODO: how do we compute precision_megabytes in case of mixed precision?
@@ -198,14 +218,14 @@ class ModelSummary(object):
@property
def named_modules(self) -> List[Tuple[str, nn.Module]]:
if self._mode == ModelSummary.MODE_FULL:
mods = self._model.named_modules()
mods = list(mods)[1:] # do not include root module (LightningModule)
elif self._mode == ModelSummary.MODE_TOP:
if self._max_depth == 0:
mods = []
elif self._max_depth == 1:
# the children are the top-level modules
mods = self._model.named_children()
else:
mods = []
mods = self._model.named_modules()
mods = list(mods)[1:] # do not include root module (LightningModule)
return list(mods)
@property
@@ -249,6 +269,12 @@ class ModelSummary(object):
self._forward_example_input()
for layer in summary.values():
layer.detach_hook()
if self._max_depth >= 1:
# remove summary entries with depth > max_depth
for k in [k for k in summary if k.count(".") >= self._max_depth]:
del summary[k]
return summary
def _forward_example_input(self) -> None:
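The `named_modules` and depth-filtering logic above hinge on the fact that `nn.Module.named_modules()` yields dot-separated qualified names, so a module's depth equals the number of dots in its name. A self-contained sketch of that rule (plain `torch.nn`, independent of `ModelSummary`):

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
    max_depth = 1
    # drop the root module (empty name), mirroring list(named_modules())[1:]
    names = [name for name, _ in model.named_modules() if name]
    kept = [n for n in names if max_depth == -1 or n.count(".") < max_depth]
    # kept == ['0', '1']; with max_depth=-1 the nested '1.0' would also be kept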

View File

@@ -941,7 +941,8 @@ class Trainer(
# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
ref_model.summarize(mode=self.weights_summary)
max_depth = ModelSummary.MODES[self.weights_summary]
ref_model.summarize(max_depth=max_depth)
# on pretrain routine end
self.on_pretrain_routine_end()
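That is, the Trainer now routes the existing `weights_summary` strings through the same `MODES` table before summarizing; a small sketch of the lookup:

    from pytorch_lightning.core.memory import ModelSummary

    # MODES == dict(top=1, full=-1) in this release (slated for removal in v1.6)
    max_depth = ModelSummary.MODES["top"]  # -> 1: only children of the root module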

View File

@@ -114,6 +114,29 @@ class LazyModel(LightningModule):
return self.layer2(self.layer1(inp))
class DeepNestedModel(LightningModule):
""" A model with deep nested layers. """
def __init__(self):
super().__init__()
self.branch1 = nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(
nn.Linear(5, 5),
nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 3))))
)
)
)
self.branch2 = nn.Linear(5, 10)
self.head = UnorderedModel()
self.example_input_array = torch.rand(2, 5)
def forward(self, inp):
return self.head(self.branch1(inp), self.branch2(inp))
def test_invalid_weights_summmary():
""" Test that invalid value for weights_summary raises an error. """
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -123,8 +146,8 @@ def test_invalid_weights_summmary():
Trainer(weights_summary='temp')
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_empty_model_summary_shapes(mode: ModelSummary):
@pytest.mark.parametrize('mode', ["full", "top"])
def test_empty_model_summary_shapes(mode: str):
""" Test that the summary works for models that have no submodules. """
model = EmptyModule()
summary = model.summarize(mode=mode)
@@ -134,7 +157,7 @@ def test_empty_model_summary_shapes(mode: ModelSummary):
@RunIf(min_gpus=1)
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
@pytest.mark.parametrize(['device'], [
pytest.param(torch.device('cpu')),
pytest.param(torch.device('cuda', 0)),
@@ -177,18 +200,18 @@ def test_mixed_dtype_model_summary():
]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
def test_hooks_removed_after_summarize(mode):
@pytest.mark.parametrize('max_depth', [-1, 0])
def test_hooks_removed_after_summarize(max_depth):
""" Test that all hooks were properly removed after summary, even ones that were not run. """
model = UnorderedModel()
summary = ModelSummary(model, mode=mode)
summary = ModelSummary(model, max_depth=max_depth)
# hooks should be removed
for _, layer in summary.summarize().items():
handle = layer._hook_handle
assert handle.id not in handle.hooks_dict_ref()
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_rnn_summary_shapes(mode):
""" Test that the model summary works for RNNs. """
model = ParityModuleRNN()
@@ -212,7 +235,7 @@ def test_rnn_summary_shapes(mode):
]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_parameter_count(mode):
""" Test that the summary counts the number of parameters in every submodule. """
model = UnorderedModel()
@@ -226,7 +249,7 @@ def test_summary_parameter_count(mode):
]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_layer_types(mode):
""" Test that the summary displays the layer names correctly. """
model = UnorderedModel()
@@ -240,7 +263,7 @@ def test_summary_layer_types(mode):
]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_summary_with_scripted_modules(mode):
model = PartialScriptModel()
summary = model.summarize(mode=mode)
@@ -249,7 +272,7 @@ def test_summary_with_scripted_modules(mode):
assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
@pytest.mark.parametrize(['example_input', 'expected_size'], [
pytest.param([], UNKNOWN_SIZE),
pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
@@ -283,7 +306,7 @@ def test_example_input_array_types(example_input, expected_size, mode):
assert summary.in_sizes == [expected_size]
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_model_size(mode):
""" Test model size is calculated correctly. """
model = PreCalculatedModel()
@@ -291,7 +314,7 @@ def test_model_size(mode):
assert model.pre_calculated_model_size == summary.model_size
@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
@pytest.mark.parametrize('mode', ["full", "top"])
def test_empty_model_size(mode):
""" Test empty model size is zero. """
model = EmptyModule()
@@ -336,3 +359,32 @@ def test_lazy_model_summary():
# https://github.com/pytorch/pytorch/issues/58350
assert summary.total_parameters == 7
assert summary.trainable_parameters == 7
def test_max_depth_equals_mode_interface():
"""Test model.summarize(full/top) interface mapping matches max_depth"""
model = DeepNestedModel()
summary_top = model.summarize(mode="top")
summary_0 = model.summarize(max_depth=1)
assert str(summary_top) == str(summary_0)
summary_full = model.summarize(mode="full")
summary_minus1 = model.summarize(max_depth=-1)
assert str(summary_full) == str(summary_minus1)
@pytest.mark.parametrize('max_depth', [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown"""
model = DeepNestedModel()
summary = ModelSummary(model, max_depth=max_depth)
for lname in summary.layer_names:
if max_depth >= 0:
assert lname.count(".") < max_depth
@pytest.mark.parametrize('max_depth', [-99, -2, "invalid"])
def test_raise_invalid_max_depth_value(max_depth):
with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"):
DeepNestedModel().summarize(max_depth=max_depth)
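A quick sketch of the boundary behavior these tests pin down, using the `DeepNestedModel` defined above:

    import pytest

    model = DeepNestedModel()
    assert ModelSummary(model, max_depth=0).layer_names == []  # depth 0: empty summary
    ModelSummary(model, max_depth=999)  # deeper than the model: everything is shown
    with pytest.raises(ValueError):
        ModelSummary(model, max_depth=-2)  # anything below -1 is rejected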

View File

@@ -16,6 +16,7 @@ import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -249,3 +250,12 @@ def test_v1_6_0_ddp_plugin_task_idx():
plugin = DDPPlugin()
with pytest.deprecated_call(match='Use `DDPPlugin.local_rank` instead'):
_ = plugin.task_idx
def test_v1_6_0_deprecated_model_summary_mode(tmpdir):
model = BoringModel()
with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"):
ModelSummary(model, mode="top")
with pytest.deprecated_call(match="Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"):
model.summarize(mode="top")
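Outside pytest, the same warnings can be observed with the standard library, assuming `rank_zero_deprecation` routes through Python's `warnings` machinery on rank zero (which `pytest.deprecated_call` above relies on); a hedged sketch:

    import warnings

    model = BoringModel()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model.summarize(mode="top")
    assert any("deprecated in v1.4" in str(w.message) for w in caught)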