Add `ModelSummary.max_depth` (#8062)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

parent 3c74502919
commit 36b893c43e
CHANGELOG.md:

```diff
@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))
 
+- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
+
 ### Changed
@@ -270,6 +273,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025))
 
+- Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
+
 ### Removed
 
 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
```
pytorch_lightning/core/lightning.py:

```diff
@@ -1642,15 +1642,24 @@ class LightningModule(
         return splits
 
-    def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
+    def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
         model_summary = None
 
-        if mode in ModelSummary.MODES:
-            model_summary = ModelSummary(self, mode=mode)
-            log.info("\n" + str(model_summary))
-        elif mode is not None:
-            raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
+        # temporary mapping from mode to max_depth
+        if max_depth is None:
+            if mode in ModelSummary.MODES:
+                max_depth = ModelSummary.MODES[mode]
+                rank_zero_deprecation(
+                    f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
+                    f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
+                )
+                model_summary = ModelSummary(self, max_depth=max_depth)
+            elif mode is not None:
+                raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
+        else:
+            model_summary = ModelSummary(self, max_depth=max_depth)
 
+        log.info("\n" + str(model_summary))
         return model_summary
 
     def freeze(self) -> None:
```
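To make the migration concrete, here is a small usage sketch (not part of the commit; it assumes PyTorch Lightning 1.4, and `LitModel` is an illustrative model): the new `max_depth` argument is honored directly, while the legacy `mode` argument still works but emits the deprecation warning added above.

```python
import torch.nn as nn
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    """Tiny illustrative model, not part of this PR."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    def forward(self, x):
        return self.net(x)


model = LitModel()

# New interface: summarize two levels of nesting.
model.summarize(max_depth=2)

# Legacy interface: still works, but warns and is mapped to max_depth=1.
model.summarize(mode="top")
```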
pytorch_lightning/core/memory.py:

```diff
@@ -131,11 +131,17 @@ class ModelSummary(object):
     Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
 
     Args:
-        model: The model to summarize (also referred to as the root module)
+        model: The model to summarize (also referred to as the root module).
         mode: Can be one of
 
-            - `top` (default): only the top-level modules will be recorded (the children of the root module)
-            - `full`: summarizes all layers and their submodules in the root module
+            - `top` (default): only the top-level modules will be recorded (the children of the root module)
+            - `full`: summarizes all layers and their submodules in the root module
 
+            .. deprecated:: v1.4
+                This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.
+
+        max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no
+            summary. Defaults to 1.
+
     The string representation of this summary prints a table with columns containing
     the name, type and number of parameters for each layer.
```
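The `max_depth` semantics can be checked with plain `torch`: module names returned by `named_modules` encode nesting depth via dots, which is what the pruning logic later in this diff relies on. A minimal sketch (the net below is illustrative, not from the commit):

```python
import torch.nn as nn

# One top-level Linear plus a nested Sequential: the non-root names are
# '0', '1', '1.0', '1.1' (the root module has the empty name and is skipped).
net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)))
names = [name for name, _ in net.named_modules()][1:]

for max_depth in (0, 1, 2, -1):
    # -1 means "show everything"; otherwise keep names shallower than max_depth
    shown = names if max_depth == -1 else [n for n in names if n.count(".") < max_depth]
    print(f"max_depth={max_depth}: {shown}")
```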
```diff
@@ -160,7 +166,7 @@ class ModelSummary(object):
     ...     return self.net(x)
     ...
     >>> model = LitModel()
-    >>> ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
+    >>> ModelSummary(model, max_depth=1)  # doctest: +NORMALIZE_WHITESPACE
       | Name | Type       | Params | In sizes  | Out sizes
     ------------------------------------------------------------
     0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
@@ -169,7 +175,7 @@ class ModelSummary(object):
     0         Non-trainable params
     132 K     Total params
     0.530     Total estimated model params size (MB)
-    >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
+    >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
       | Name | Type       | Params | In sizes  | Out sizes
     --------------------------------------------------------------
     0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
```
```diff
@@ -182,14 +188,28 @@ class ModelSummary(object):
     0.530     Total estimated model params size (MB)
     """
 
-    MODE_TOP = "top"
-    MODE_FULL = "full"
-    MODE_DEFAULT = MODE_TOP
-    MODES = [MODE_FULL, MODE_TOP]
+    MODES = dict(top=1, full=-1)  # TODO: remove in v1.6
 
-    def __init__(self, model, mode: str = MODE_DEFAULT):
+    def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
         self._model = model
-        self._mode = mode
+
+        # temporary mapping from mode to max_depth
+        if max_depth is None or mode is not None:
+            if mode in ModelSummary.MODES:
+                max_depth = ModelSummary.MODES[mode]
+                from pytorch_lightning.utilities import rank_zero_deprecation
+                rank_zero_deprecation(
+                    f"Argument `mode` in `ModelSummary` is deprecated in v1.4"
+                    f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
+                )
+            else:
+                from pytorch_lightning.utilities.exceptions import MisconfigurationException
+                raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
+
+        if not isinstance(max_depth, int) or max_depth < -1:
+            raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
+
+        self._max_depth = max_depth
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
```
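The three paths through this constructor (explicit `max_depth`, deprecated `mode`, invalid input) can be replayed in a standalone sketch. `resolve_max_depth` is a hypothetical helper for illustration only, and a plain `ValueError` stands in for `MisconfigurationException` to keep it dependency-free:

```python
from typing import Optional

MODES = dict(top=1, full=-1)  # mirrors ModelSummary.MODES above


def resolve_max_depth(mode: Optional[str] = None, max_depth: Optional[int] = 1) -> int:
    # temporary mapping from mode to max_depth, as in the diff
    if max_depth is None or mode is not None:
        if mode in MODES:
            max_depth = MODES[mode]
        else:
            raise ValueError(f"`mode` can be {', '.join(MODES)}, got {mode}.")
    if not isinstance(max_depth, int) or max_depth < -1:
        raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
    return max_depth


assert resolve_max_depth() == 1                         # defaults
assert resolve_max_depth(max_depth=3) == 3              # new argument, no mode given
assert resolve_max_depth(mode="full") == -1             # deprecated path
assert resolve_max_depth(mode="top", max_depth=5) == 1  # a non-None mode overrides max_depth
```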
```diff
@@ -198,14 +218,14 @@ class ModelSummary(object):
 
     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
-        if self._mode == ModelSummary.MODE_FULL:
-            mods = self._model.named_modules()
-            mods = list(mods)[1:]  # do not include root module (LightningModule)
-        elif self._mode == ModelSummary.MODE_TOP:
+        if self._max_depth == 0:
+            mods = []
+        elif self._max_depth == 1:
             # the children are the top-level modules
             mods = self._model.named_children()
         else:
-            mods = []
+            mods = self._model.named_modules()
+            mods = list(mods)[1:]  # do not include root module (LightningModule)
         return list(mods)
 
     @property
```
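The branch above leans on the difference between two standard `torch.nn.Module` iterators, shown here on a throwaway module: `named_children` yields exactly the depth-1 modules, while `named_modules` walks the whole tree with the root first (hence the `[1:]` slice).

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))

print([name for name, _ in model.named_children()])     # ['0', '1']
print([name for name, _ in model.named_modules()][1:])  # ['0', '1', '1.0']
```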
```diff
@@ -249,6 +269,12 @@ class ModelSummary(object):
             self._forward_example_input()
         for layer in summary.values():
             layer.detach_hook()
 
+        if self._max_depth >= 1:
+            # remove summary entries with depth > max_depth
+            for k in [k for k in summary if k.count(".") >= self._max_depth]:
+                del summary[k]
+
         return summary
 
     def _forward_example_input(self) -> None:
```
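The pruning step in isolation, using a plain dict as a stand-in for the real ordered mapping of `LayerSummary` objects: a name with `max_depth` or more dots lives deeper than requested and is dropped.

```python
# Stand-in for the summary mapping; the values do not matter here.
summary = {"branch1": None, "branch1.0": None, "branch1.1": None, "branch1.1.0": None}

max_depth = 2
if max_depth >= 1:
    # remove summary entries with depth > max_depth
    for k in [k for k in summary if k.count(".") >= max_depth]:
        del summary[k]

print(list(summary))  # ['branch1', 'branch1.0', 'branch1.1']
```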
pytorch_lightning/trainer/trainer.py:

```diff
@@ -941,7 +941,8 @@ class Trainer(
 
         # print model summary
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
-            ref_model.summarize(mode=self.weights_summary)
+            max_depth = ModelSummary.MODES[self.weights_summary]
+            ref_model.summarize(max_depth=max_depth)
 
         # on pretrain routine end
         self.on_pretrain_routine_end()
```
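For users, the net effect of the `Trainer` change is that the existing `weights_summary` strings keep working; they are translated to a depth through `ModelSummary.MODES` before `summarize` is called. A quick check, assuming the v1.4 import path that the tests below also use:

```python
from pytorch_lightning.core.memory import ModelSummary

assert ModelSummary.MODES["top"] == 1    # Trainer(weights_summary="top")  -> max_depth=1
assert ModelSummary.MODES["full"] == -1  # Trainer(weights_summary="full") -> max_depth=-1
```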
tests/core/test_memory.py:

```diff
@@ -114,6 +114,29 @@ class LazyModel(LightningModule):
         return self.layer2(self.layer1(inp))
 
 
+class DeepNestedModel(LightningModule):
+    """ A model with deep nested layers. """
+
+    def __init__(self):
+        super().__init__()
+        self.branch1 = nn.Sequential(
+            nn.Linear(5, 5),
+            nn.Sequential(
+                nn.Linear(5, 5),
+                nn.Sequential(
+                    nn.Linear(5, 5),
+                    nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 3))))
+                )
+            )
+        )
+        self.branch2 = nn.Linear(5, 10)
+        self.head = UnorderedModel()
+        self.example_input_array = torch.rand(2, 5)
+
+    def forward(self, inp):
+        return self.head(self.branch1(inp), self.branch2(inp))
+
+
 def test_invalid_weights_summmary():
     """ Test that invalid value for weights_summary raises an error. """
     with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
```
```diff
@@ -123,8 +146,8 @@ def test_invalid_weights_summmary():
         Trainer(weights_summary='temp')
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
-def test_empty_model_summary_shapes(mode: ModelSummary):
+@pytest.mark.parametrize('mode', ["full", "top"])
+def test_empty_model_summary_shapes(mode: str):
     """ Test that the summary works for models that have no submodules. """
     model = EmptyModule()
     summary = model.summarize(mode=mode)
```
```diff
@@ -134,7 +157,7 @@ def test_empty_model_summary_shapes(mode: ModelSummary):
 
 
 @RunIf(min_gpus=1)
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 @pytest.mark.parametrize(['device'], [
     pytest.param(torch.device('cpu')),
     pytest.param(torch.device('cuda', 0)),
```
```diff
@@ -177,18 +200,18 @@ def test_mixed_dtype_model_summary():
     ]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
-def test_hooks_removed_after_summarize(mode):
+@pytest.mark.parametrize('max_depth', [-1, 0])
+def test_hooks_removed_after_summarize(max_depth):
     """ Test that all hooks were properly removed after summary, even ones that were not run. """
     model = UnorderedModel()
-    summary = ModelSummary(model, mode=mode)
+    summary = ModelSummary(model, max_depth=max_depth)
     # hooks should be removed
     for _, layer in summary.summarize().items():
         handle = layer._hook_handle
         assert handle.id not in handle.hooks_dict_ref()
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_rnn_summary_shapes(mode):
     """ Test that the model summary works for RNNs. """
     model = ParityModuleRNN()
```
```diff
@@ -212,7 +235,7 @@ def test_rnn_summary_shapes(mode):
     ]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_summary_parameter_count(mode):
     """ Test that the summary counts the number of parameters in every submodule. """
     model = UnorderedModel()
```
```diff
@@ -226,7 +249,7 @@ def test_summary_parameter_count(mode):
     ]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_summary_layer_types(mode):
     """ Test that the summary displays the layer names correctly. """
     model = UnorderedModel()
```
```diff
@@ -240,7 +263,7 @@ def test_summary_layer_types(mode):
     ]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_summary_with_scripted_modules(mode):
     model = PartialScriptModel()
     summary = model.summarize(mode=mode)
```
```diff
@@ -249,7 +272,7 @@ def test_summary_with_scripted_modules(mode):
     assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 @pytest.mark.parametrize(['example_input', 'expected_size'], [
     pytest.param([], UNKNOWN_SIZE),
     pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
```
```diff
@@ -283,7 +306,7 @@ def test_example_input_array_types(example_input, expected_size, mode):
     assert summary.in_sizes == [expected_size]
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_model_size(mode):
     """ Test model size is calculated correctly. """
     model = PreCalculatedModel()
```
```diff
@@ -291,7 +314,7 @@ def test_model_size(mode):
     assert model.pre_calculated_model_size == summary.model_size
 
 
-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
+@pytest.mark.parametrize('mode', ["full", "top"])
 def test_empty_model_size(mode):
     """ Test empty model size is zero. """
     model = EmptyModule()
```
```diff
@@ -336,3 +359,32 @@ def test_lazy_model_summary():
         # https://github.com/pytorch/pytorch/issues/58350
         assert summary.total_parameters == 7
         assert summary.trainable_parameters == 7
+
+
+def test_max_depth_equals_mode_interface():
+    """Test model.summarize(full/top) interface mapping matches max_depth"""
+    model = DeepNestedModel()
+
+    summary_top = model.summarize(mode="top")
+    summary_0 = model.summarize(max_depth=1)
+    assert str(summary_top) == str(summary_0)
+
+    summary_full = model.summarize(mode="full")
+    summary_minus1 = model.summarize(max_depth=-1)
+    assert str(summary_full) == str(summary_minus1)
+
+
+@pytest.mark.parametrize('max_depth', [-1, 0, 1, 3, 999])
+def test_max_depth_param(max_depth):
+    """Test that only the modules up to the desired depth are shown"""
+    model = DeepNestedModel()
+    summary = ModelSummary(model, max_depth=max_depth)
+    for lname in summary.layer_names:
+        if max_depth >= 0:
+            assert lname.count(".") < max_depth
+
+
+@pytest.mark.parametrize('max_depth', [-99, -2, "invalid"])
+def test_raise_invalid_max_depth_value(max_depth):
+    with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"):
+        DeepNestedModel().summarize(max_depth=max_depth)
```
tests/deprecated_api/test_remove_1-6.py:

```diff
@@ -16,6 +16,7 @@ import pytest
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.model_helpers import is_overridden
```
```diff
@@ -249,3 +250,12 @@ def test_v1_6_0_ddp_plugin_task_idx():
     plugin = DDPPlugin()
     with pytest.deprecated_call(match='Use `DDPPlugin.local_rank` instead'):
         _ = plugin.task_idx
+
+
+def test_v1_6_0_deprecated_model_summary_mode(tmpdir):
+    model = BoringModel()
+    with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"):
+        ModelSummary(model, mode="top")
+
+    with pytest.deprecated_call(match="Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"):
+        model.summarize(mode="top")
```