From d773407e596e73d728241d4425cf9d049f6e793e Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 10 Sep 2021 18:12:42 +0530 Subject: [PATCH] feat: Add ModelSummary Callback (#9344) Co-authored-by: Ethan Harris Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + pytorch_lightning/callbacks/__init__.py | 2 + pytorch_lightning/callbacks/model_summary.py | 72 ++++++++++++++++ .../trainer/connectors/callback_connector.py | 20 ++++- pytorch_lightning/trainer/trainer.py | 12 +-- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/enums.py | 34 +++++++- pytorch_lightning/utilities/model_summary.py | 28 +++--- tests/callbacks/test_model_summary.py | 86 +++++++++++++++++++ .../connectors/test_callback_connector.py | 16 ++-- tests/utilities/test_enums.py | 28 +++++- 11 files changed, 269 insertions(+), 33 deletions(-) create mode 100644 pytorch_lightning/callbacks/model_summary.py create mode 100644 tests/callbacks/test_model_summary.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c2651b6f4..fe8738266f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373)) +- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index d2c405b5c2..f02518c14b 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -19,6 +19,7 @@ from pytorch_lightning.callbacks.gradient_accumulation_scheduler import Gradient from pytorch_lightning.callbacks.lambda_function import LambdaCallback from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.model_summary import ModelSummary from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning @@ -39,6 +40,7 @@ __all__ = [ "LearningRateMonitor", "ModelCheckpoint", "ModelPruning", + "ModelSummary", "BasePredictionWriter", "ProgressBar", "ProgressBarBase", diff --git a/pytorch_lightning/callbacks/model_summary.py b/pytorch_lightning/callbacks/model_summary.py new file mode 100644 index 0000000000..1ebf744fa8 --- /dev/null +++ b/pytorch_lightning/callbacks/model_summary.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Model Summary +============= + +Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + +The string representation of this summary prints a table with columns containing +the name, type and number of parameters for each layer. + +""" +import logging +from typing import List, Optional, Union + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize + +log = logging.getLogger(__name__) + + +class ModelSummary(Callback): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + + Args: + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. + + Example:: + + >>> from pytorch_lightning import Trainer + >>> from pytorch_lightning.callbacks import ModelSummary + >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)]) + """ + + def __init__(self, max_depth: Optional[int] = 1): + self._max_depth: int = max_depth + + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if trainer.is_global_zero and self._max_depth is not None and not trainer.testing: + model_summary = summarize(pl_module, max_depth=self._max_depth) + + summary_data = model_summary._get_summary_data() + total_parameters = model_summary.total_parameters + trainable_parameters = model_summary.trainable_parameters + model_size = model_summary.model_size + + self.summarize(summary_data, total_parameters, trainable_parameters, model_size) + + @staticmethod + def summarize( + summary_data: List[List[Union[str, List[str]]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data) + + log.info("\n" + summary_table) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 57e95b4446..2a2b5d15de 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,9 +15,9 @@ import os from datetime import timedelta from typing import Dict, List, Optional, Union -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -34,6 +34,7 @@ class CallbackConnector: process_position: int, default_root_dir: Optional[str], weights_save_path: Optional[str], + weights_summary: Optional[str], stochastic_weight_avg: bool, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, ): @@ -58,6 +59,8 @@ class CallbackConnector: # responsible to stop the training when max_time is reached. self._configure_timer_callback(max_time) + self._configure_model_summary_callback(weights_summary) + # init progress bar if process_position != 0: rank_zero_deprecation( @@ -89,6 +92,19 @@ class CallbackConnector: if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: self.trainer.callbacks.append(ModelCheckpoint()) + def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None: + if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks): + return + if weights_summary is not None: + if weights_summary not in ModelSummaryMode.supported_types(): + raise MisconfigurationException( + f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}", + f" but got {weights_summary}", + ) + max_depth = ModelSummaryMode.get_max_depth(weights_summary) + model_summary = ModelSummary(max_depth=max_depth) + self.trainer.callbacks.append(model_summary) + def _configure_swa_callbacks(self): if not self.trainer._stochastic_weight_avg: return diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 963741b995..eb905e85c8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -79,7 +79,6 @@ from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9 from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.model_summary import ModelSummary, summarize from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -407,11 +406,6 @@ class Trainer( # default .predict() loop self.predict_loop = PredictionLoop() - # training state - if weights_summary is not None and weights_summary not in ModelSummary.MODES: - raise MisconfigurationException( - f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}" - ) self.weights_summary = weights_summary # init callbacks @@ -423,6 +417,7 @@ class Trainer( process_position, default_root_dir, weights_save_path, + self.weights_summary, stochastic_weight_avg, max_time, ) @@ -1108,11 +1103,6 @@ class Trainer( # -------------------------- self.call_hook("on_pretrain_routine_start") - # print model summary - if self.is_global_zero and self.weights_summary is not None and not self.testing: - max_depth = ModelSummary.MODES[self.weights_summary] - summarize(self.lightning_module, max_depth=max_depth) - self.call_hook("on_pretrain_routine_end") def _run_train(self) -> None: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 66c7721599..5e6b99f1ce 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -23,6 +23,7 @@ from pytorch_lightning.utilities.enums import ( # noqa: F401 DistributedType, GradClipAlgorithmType, LightningEnum, + ModelSummaryMode, ) from pytorch_lightning.utilities.grads import grad_norm # noqa: F401 from pytorch_lightning.utilities.imports import ( # noqa: F401 diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 50c52fd57a..0a08f625ac 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -73,7 +73,7 @@ class PrecisionType(LightningEnum): class DistributedType(LightningEnum): - """Define type of ditributed computing. + """Define type of distributed computing. >>> # you can match the type with string >>> DistributedType.DDP == 'ddp' @@ -147,3 +147,35 @@ class AutoRestartBatchKeys(LightningEnum): """Defines special dictionary keys used to track captured dataset state with multiple workers.""" PL_RESTART_META = "__pl_restart_meta" + + +class ModelSummaryMode(LightningEnum): + # TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`) + """Define the Model Summary mode to be used. + + Can be one of + - `top`: only the top-level modules will be recorded (the children of the root module) + - `full`: summarizes all layers and their submodules in the root module + + >>> # you can match the type with string + >>> ModelSummaryMode.TOP == 'TOP' + True + >>> # which is case invariant + >>> ModelSummaryMode.TOP in ('top', 'FULL') + True + """ + + TOP = "top" + FULL = "full" + + @staticmethod + def get_max_depth(mode: str) -> int: + if mode == ModelSummaryMode.TOP: + return 1 + if mode == ModelSummaryMode.FULL: + return -1 + raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.") + + @staticmethod + def supported_types() -> List[str]: + return [x.value for x in ModelSummaryMode] diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index 727779162f..54b459964d 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -23,7 +23,7 @@ from torch import Tensor from torch.utils.hooks import RemovableHandle import pytorch_lightning as pl -from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation +from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from pytorch_lightning.utilities.warnings import WarningCache @@ -185,21 +185,21 @@ class ModelSummary: 0.530 Total estimated model params size (MB) """ - MODES = dict(top=1, full=-1) # TODO: remove in v1.6 - def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1): self._model = model # temporary mapping from mode to max_depth if max_depth is None or mode is not None: - if mode in ModelSummary.MODES: - max_depth = ModelSummary.MODES[mode] + if mode in ModelSummaryMode.supported_types(): + max_depth = ModelSummaryMode.get_max_depth(mode) rank_zero_deprecation( "Argument `mode` in `ModelSummary` is deprecated in v1.4" f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour." ) else: - raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.") + raise MisconfigurationException( + f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}." + ) if not isinstance(max_depth, int) or max_depth < -1: raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.") @@ -295,7 +295,7 @@ class ModelSummary: model(input_) model.train(mode) # restore mode of module - def __str__(self): + def _get_summary_data(self): """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size @@ -310,6 +310,11 @@ class ModelSummary: arrays.append(["In sizes", self.in_sizes]) arrays.append(["Out sizes", self.out_sizes]) + return arrays + + def __str__(self): + arrays = self._get_summary_data() + total_parameters = self.total_parameters trainable_parameters = self.trainable_parameters model_size = self.model_size @@ -445,16 +450,17 @@ def summarize( # temporary mapping from mode to max_depth if max_depth is None: - if mode in ModelSummary.MODES: - max_depth = ModelSummary.MODES[mode] + if mode in ModelSummaryMode.supported_types(): + max_depth = ModelSummaryMode.get_max_depth(mode) rank_zero_deprecation( "Argument `mode` in `LightningModule.summarize` is deprecated in v1.4" f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior." ) model_summary = ModelSummary(lightning_module, max_depth=max_depth) elif mode is not None: - raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}") + raise MisconfigurationException( + f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}" + ) else: model_summary = ModelSummary(lightning_module, max_depth=max_depth) - log.info("\n" + str(model_summary)) return model_summary diff --git a/tests/callbacks/test_model_summary.py b/tests/callbacks/test_model_summary.py new file mode 100644 index 0000000000..a0264186d9 --- /dev/null +++ b/tests/callbacks/test_model_summary.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Union + +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities import ModelSummaryMode +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel + + +def test_model_summary_callback_present_trainer(): + + trainer = Trainer() + assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + trainer = Trainer(callbacks=ModelSummary()) + assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + +def test_model_summary_callback_with_weights_summary_none(): + + trainer = Trainer(weights_summary=None) + assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + +def test_model_summary_callback_with_weights_summary(): + + trainer = Trainer(weights_summary="top") + + model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0] + assert model_summary_callback._max_depth == 1 + + trainer = Trainer(weights_summary="full") + + model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0] + assert model_summary_callback._max_depth == -1 + + with pytest.raises( + MisconfigurationException, match=f"`weights_summary` can be None, {', '.join(list(ModelSummaryMode))}" + ): + _ = Trainer(weights_summary="invalid") + + +def test_model_summary_callback_override_weights_summary_flag(): + + trainer = Trainer(callbacks=ModelSummary(), weights_summary=None) + assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + +def test_custom_model_summary_callback_summarize(tmpdir): + class CustomModelSummary(ModelSummary): + @staticmethod + def summarize( + summary_data: List[List[Union[str, List[str]]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + assert summary_data[1][0] == "Name" + assert summary_data[1][1][0] == "layer" + + assert summary_data[2][0] == "Type" + assert summary_data[2][1][0] == "Linear" + + assert summary_data[3][0] == "Params" + assert total_parameters == 66 + assert trainable_parameters == 66 + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, callbacks=CustomModelSummary(), max_steps=1) + + trainer.fit(model) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 1973b2f2c7..713344ae2b 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -21,6 +21,7 @@ from pytorch_lightning.callbacks import ( GradientAccumulationScheduler, LearningRateMonitor, ModelCheckpoint, + ModelSummary, ProgressBar, ) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector @@ -31,31 +32,32 @@ def test_checkpoint_callbacks_are_last(tmpdir): """Test that checkpoint callbacks always get moved to the end of the list, with preserved order.""" checkpoint1 = ModelCheckpoint(tmpdir) checkpoint2 = ModelCheckpoint(tmpdir) + model_summary = ModelSummary() early_stopping = EarlyStopping() lr_monitor = LearningRateMonitor() progress_bar = ProgressBar() # no model reference - trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) + trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() - assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2] # no model callbacks model = LightningModule() model.configure_callbacks = lambda: [] trainer.model = model cb_connector._attach_model_callbacks() - assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2] # with model-specific callbacks that substitute ones in Trainer model = LightningModule() - model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2] + model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() - assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2] + assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, model_summary, checkpoint1, checkpoint2] class StatefulCallback0(Callback): @@ -119,7 +121,9 @@ def test_attach_model_callbacks(): def assert_composition(trainer_callbacks, model_callbacks, expected): model = LightningModule() model.configure_callbacks = lambda: model_callbacks - trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) + trainer = Trainer( + checkpoint_callback=False, progress_bar_refresh_rate=0, weights_summary=None, callbacks=trainer_callbacks + ) trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py index c92ce938c7..df95a2b0f5 100644 --- a/tests/utilities/test_enums.py +++ b/tests/utilities/test_enums.py @@ -1,5 +1,19 @@ -from pytorch_lightning.utilities import DeviceType -from pytorch_lightning.utilities.enums import PrecisionType +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning.utilities.enums import DeviceType, ModelSummaryMode, PrecisionType def test_consistency(): @@ -18,3 +32,13 @@ def test_precision_supported_types(): assert PrecisionType.supported_type("16") assert not PrecisionType.supported_type(1) assert not PrecisionType.supported_type("invalid") + + +def test_model_summary_mode(): + assert ModelSummaryMode.supported_types() == ["top", "full"] + assert ModelSummaryMode.TOP in ("top", "full") + assert ModelSummaryMode.get_max_depth("top") == 1 + assert ModelSummaryMode.get_max_depth("full") == -1 + + with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."): + ModelSummaryMode.get_max_depth("invalid")