diff --git a/CHANGELOG.md b/CHANGELOG.md
index be7ff7037f..e5f9b75810 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))


+- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
+
+
 ### Changed


@@ -270,6 +273,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#8025](https://github.com/PyTorchLightning/pytorch-lightning/pull/8025))


+- Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))
+
+
 ### Removed

 - Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index ab8263bc8d..b7ce88eace 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1642,15 +1642,24 @@ class LightningModule(

         return splits

-    def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
+    def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
         model_summary = None

-        if mode in ModelSummary.MODES:
-            model_summary = ModelSummary(self, mode=mode)
-            log.info("\n" + str(model_summary))
-        elif mode is not None:
-            raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
+        # temporary mapping from mode to max_depth
+        if max_depth is None:
+            if mode in ModelSummary.MODES:
+                max_depth = ModelSummary.MODES[mode]
+                rank_zero_deprecation(
+                    f"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
+                    f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
+                )
+                model_summary = ModelSummary(self, max_depth=max_depth)
+            elif mode is not None:
+                raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
+        else:
+            model_summary = ModelSummary(self, max_depth=max_depth)

+        log.info("\n" + str(model_summary))
         return model_summary

     def freeze(self) -> None:
diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index 255946ec0c..bba42d6997 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -131,11 +131,17 @@ class ModelSummary(object):
     Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

     Args:
-        model: The model to summarize (also referred to as the root module)
+        model: The model to summarize (also referred to as the root module).
         mode: Can be one of

-             - `top` (default): only the top-level modules will be recorded (the children of the root module)
-             - `full`: summarizes all layers and their submodules in the root module
+            - `top` (default): only the top-level modules will be recorded (the children of the root module)
+            - `full`: summarizes all layers and their submodules in the root module
+
+            .. deprecated:: v1.4
+                This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.
+
+        max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no
+            summary. Defaults to 1.

     The string representation of this summary prints a table with columns containing
     the name, type and number of parameters for each layer.
@@ -160,7 +166,7 @@
         ...     return self.net(x)
         ...
         >>> model = LitModel()
-        >>> ModelSummary(model, mode='top')  # doctest: +NORMALIZE_WHITESPACE
+        >>> ModelSummary(model, max_depth=1)  # doctest: +NORMALIZE_WHITESPACE
          | Name | Type       | Params | In sizes  | Out sizes
        ------------------------------------------------------------
        0 | net  | Sequential | 132 K  | [10, 256] | [10, 512]
@@ -169,7 +175,7 @@
        0         Non-trainable params
        132 K     Total params
        0.530     Total estimated model params size (MB)
-        >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
+        >>> ModelSummary(model, max_depth=-1)  # doctest: +NORMALIZE_WHITESPACE
          | Name  | Type        | Params | In sizes  | Out sizes
        --------------------------------------------------------------
        0 | net   | Sequential  | 132 K  | [10, 256] | [10, 512]
@@ -182,14 +188,28 @@ class ModelSummary(object):
        0.530     Total estimated model params size (MB)
     """

-    MODE_TOP = "top"
-    MODE_FULL = "full"
-    MODE_DEFAULT = MODE_TOP
-    MODES = [MODE_FULL, MODE_TOP]
+    MODES = dict(top=1, full=-1)  # TODO: remove in v1.6

-    def __init__(self, model, mode: str = MODE_DEFAULT):
+    def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
         self._model = model
-        self._mode = mode
+
+        # temporary mapping from mode to max_depth
+        if max_depth is None or mode is not None:
+            if mode in ModelSummary.MODES:
+                max_depth = ModelSummary.MODES[mode]
+                from pytorch_lightning.utilities import rank_zero_deprecation
+                rank_zero_deprecation(
+                    f"Argument `mode` in `ModelSummary` is deprecated in v1.4"
+                    f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
+                )
+            else:
+                from pytorch_lightning.utilities.exceptions import MisconfigurationException
+                raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
+
+        if not isinstance(max_depth, int) or max_depth < -1:
+            raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
+
+        self._max_depth = max_depth
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precisin_megabytes in case of mixed precision?
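# --- Editor's note (illustrative sketch, not part of the patch) --------------------
# A quick look at how the new `max_depth` argument introduced above is meant to be
# called, based on the docstring and `__init__` changes. `TinyModel` is a made-up
# module used only for this example; the exact table printed depends on the model,
# and passing `mode=` still works but emits the deprecation warning added above.
import torch
from torch import nn
from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # one top-level child that itself contains a nested Sequential
        self.net = nn.Sequential(nn.Linear(4, 8), nn.Sequential(nn.Linear(8, 2)))
        self.example_input_array = torch.rand(1, 4)

    def forward(self, x):
        return self.net(x)


model = TinyModel()
print(ModelSummary(model, max_depth=1))   # direct children only, like the old mode="top"
print(ModelSummary(model, max_depth=-1))  # recurse into every submodule, like the old mode="full"
print(ModelSummary(model, max_depth=0))   # no per-layer rows, only the parameter totals
print(ModelSummary(model, mode="full"))   # still accepted, but warns and maps to max_depth=-1
# ------------------------------------------------------------------------------------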
@@ -198,14 +218,14 @@ class ModelSummary(object):

     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
-        if self._mode == ModelSummary.MODE_FULL:
-            mods = self._model.named_modules()
-            mods = list(mods)[1:]  # do not include root module (LightningModule)
-        elif self._mode == ModelSummary.MODE_TOP:
+        if self._max_depth == 0:
+            mods = []
+        elif self._max_depth == 1:
             # the children are the top-level modules
             mods = self._model.named_children()
         else:
-            mods = []
+            mods = self._model.named_modules()
+            mods = list(mods)[1:]  # do not include root module (LightningModule)
         return list(mods)

     @property
@@ -249,6 +269,12 @@ class ModelSummary(object):
             self._forward_example_input()
         for layer in summary.values():
             layer.detach_hook()
+
+        if self._max_depth >= 1:
+            # remove summary entries with depth > max_depth
+            for k in [k for k in summary if k.count(".") >= self._max_depth]:
+                del summary[k]
+
         return summary

     def _forward_example_input(self) -> None:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dcf04760c6..6e47702e38 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -941,7 +941,8 @@ class Trainer(

         # print model summary
         if self.is_global_zero and self.weights_summary is not None and not self.testing:
-            ref_model.summarize(mode=self.weights_summary)
+            max_depth = ModelSummary.MODES[self.weights_summary]
+            ref_model.summarize(max_depth=max_depth)

         # on pretrain routine end
         self.on_pretrain_routine_end()
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index a094ddfc3d..96e1bfaec1 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -114,6 +114,29 @@ class LazyModel(LightningModule):
         return self.layer2(self.layer1(inp))


+class DeepNestedModel(LightningModule):
+    """ A model with deep nested layers. """
+
+    def __init__(self):
+        super().__init__()
+        self.branch1 = nn.Sequential(
+            nn.Linear(5, 5),
+            nn.Sequential(
+                nn.Linear(5, 5),
+                nn.Sequential(
+                    nn.Linear(5, 5),
+                    nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 3))))
+                )
+            )
+        )
+        self.branch2 = nn.Linear(5, 10)
+        self.head = UnorderedModel()
+        self.example_input_array = torch.rand(2, 5)
+
+    def forward(self, inp):
+        return self.head(self.branch1(inp), self.branch2(inp))
+
+
 def test_invalid_weights_summmary():
     """ Test that invalid value for weights_summary raises an error. """
     with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
@@ -123,8 +146,8 @@ def test_invalid_weights_summmary():
         Trainer(weights_summary='temp')


-@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP])
-def test_empty_model_summary_shapes(mode: ModelSummary):
+@pytest.mark.parametrize('mode', ["full", "top"])
+def test_empty_model_summary_shapes(mode: str):
     """ Test that the summary works for models that have no submodules. """
""" model = EmptyModule() summary = model.summarize(mode=mode) @@ -134,7 +157,7 @@ def test_empty_model_summary_shapes(mode: ModelSummary): @RunIf(min_gpus=1) -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) @pytest.mark.parametrize(['device'], [ pytest.param(torch.device('cpu')), pytest.param(torch.device('cuda', 0)), @@ -177,18 +200,18 @@ def test_mixed_dtype_model_summary(): ] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) -def test_hooks_removed_after_summarize(mode): +@pytest.mark.parametrize('max_depth', [-1, 0]) +def test_hooks_removed_after_summarize(max_depth): """ Test that all hooks were properly removed after summary, even ones that were not run. """ model = UnorderedModel() - summary = ModelSummary(model, mode=mode) + summary = ModelSummary(model, max_depth=max_depth) # hooks should be removed for _, layer in summary.summarize().items(): handle = layer._hook_handle assert handle.id not in handle.hooks_dict_ref() -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_rnn_summary_shapes(mode): """ Test that the model summary works for RNNs. """ model = ParityModuleRNN() @@ -212,7 +235,7 @@ def test_rnn_summary_shapes(mode): ] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_summary_parameter_count(mode): """ Test that the summary counts the number of parameters in every submodule. """ model = UnorderedModel() @@ -226,7 +249,7 @@ def test_summary_parameter_count(mode): ] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_summary_layer_types(mode): """ Test that the summary displays the layer names correctly. """ model = UnorderedModel() @@ -240,7 +263,7 @@ def test_summary_layer_types(mode): ] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_summary_with_scripted_modules(mode): model = PartialScriptModel() summary = model.summarize(mode=mode) @@ -249,7 +272,7 @@ def test_summary_with_scripted_modules(mode): assert summary.out_sizes == [UNKNOWN_SIZE, [2, 2]] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) @pytest.mark.parametrize(['example_input', 'expected_size'], [ pytest.param([], UNKNOWN_SIZE), pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3), @@ -283,7 +306,7 @@ def test_example_input_array_types(example_input, expected_size, mode): assert summary.in_sizes == [expected_size] -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_model_size(mode): """ Test model size is calculated correctly. """ model = PreCalculatedModel() @@ -291,7 +314,7 @@ def test_model_size(mode): assert model.pre_calculated_model_size == summary.model_size -@pytest.mark.parametrize('mode', [ModelSummary.MODE_FULL, ModelSummary.MODE_TOP]) +@pytest.mark.parametrize('mode', ["full", "top"]) def test_empty_model_size(mode): """ Test empty model size is zero. 
""" model = EmptyModule() @@ -336,3 +359,32 @@ def test_lazy_model_summary(): # https://github.com/pytorch/pytorch/issues/58350 assert summary.total_parameters == 7 assert summary.trainable_parameters == 7 + + +def test_max_depth_equals_mode_interface(): + """Test model.summarize(full/top) interface mapping matches max_depth""" + model = DeepNestedModel() + + summary_top = model.summarize(mode="top") + summary_0 = model.summarize(max_depth=1) + assert str(summary_top) == str(summary_0) + + summary_full = model.summarize(mode="full") + summary_minus1 = model.summarize(max_depth=-1) + assert str(summary_full) == str(summary_minus1) + + +@pytest.mark.parametrize('max_depth', [-1, 0, 1, 3, 999]) +def test_max_depth_param(max_depth): + """Test that only the modules up to the desired depth are shown""" + model = DeepNestedModel() + summary = ModelSummary(model, max_depth=max_depth) + for lname in summary.layer_names: + if max_depth >= 0: + assert lname.count(".") < max_depth + + +@pytest.mark.parametrize('max_depth', [-99, -2, "invalid"]) +def test_raise_invalid_max_depth_value(max_depth): + with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"): + DeepNestedModel().summarize(max_depth=max_depth) diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 5fe6bb210d..2ac4b110d8 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -16,6 +16,7 @@ import pytest from pytorch_lightning import Trainer from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.plugins.training_type import DDPPlugin, DDPSpawnPlugin from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.model_helpers import is_overridden @@ -249,3 +250,12 @@ def test_v1_6_0_ddp_plugin_task_idx(): plugin = DDPPlugin() with pytest.deprecated_call(match='Use `DDPPlugin.local_rank` instead'): _ = plugin.task_idx + + +def test_v1_6_0_deprecated_model_summary_mode(tmpdir): + model = BoringModel() + with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"): + ModelSummary(model, mode="top") + + with pytest.deprecated_call(match="Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"): + model.summarize(mode="top")