From 2ec39e275ec10014da6b72af3cab4efad08a096e Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 17 Sep 2021 16:24:16 +0530 Subject: [PATCH] Add `RichModelSummary` callback (#9546) --- CHANGELOG.md | 3 + docs/source/extensions/callbacks.rst | 2 + pyproject.toml | 1 + pytorch_lightning/callbacks/__init__.py | 4 +- .../callbacks/rich_model_summary.py | 109 ++++++++++++++++++ .../trainer/connectors/callback_connector.py | 22 +++- tests/callbacks/test_rich_model_summary.py | 35 ++++++ 7 files changed, 171 insertions(+), 5 deletions(-) create mode 100644 pytorch_lightning/callbacks/rich_model_summary.py create mode 100644 tests/callbacks/test_rich_model_summary.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6147c33c25..d3b4a38dd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389)) +- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546)) + + ### Changed - `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)). diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index d5cfdfc811..ad61c10a7b 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -106,8 +106,10 @@ Lightning has a few built-in callbacks. LearningRateMonitor ModelCheckpoint ModelPruning + ModelSummary ProgressBar ProgressBarBase + RichModelSummary RichProgressBar QuantizationAwareTraining StochasticWeightAveraging diff --git a/pyproject.toml b/pyproject.toml index d205af7f0a..9981d6827e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ ignore_errors = "True" module = [ "pytorch_lightning.callbacks.model_summary", "pytorch_lightning.callbacks.pruning", + "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.loops.optimization.*", "pytorch_lightning.loops.evaluation_loop", "pytorch_lightning.trainer.connectors.checkpoint_connector", diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index f02518c14b..98cf5df7ca 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -24,6 +24,7 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining +from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor @@ -45,7 +46,8 @@ __all__ = [ "ProgressBar", "ProgressBarBase", "QuantizationAwareTraining", + "RichModelSummary", + "RichProgressBar", "StochasticWeightAveraging", "Timer", - "RichProgressBar", ] diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py new file mode 100644 index 0000000000..2e55665c44 --- /dev/null +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -0,0 +1,109 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + +from pytorch_lightning.callbacks import ModelSummary +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE +from pytorch_lightning.utilities.model_summary import get_human_readable_count + +if _RICH_AVAILABLE: + from rich.console import Console + from rich.table import Table + + +class RichModelSummary(ModelSummary): + r""" + Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule` + with `rich text formatting `_. + + Install it with pip: + + .. code-block:: bash + + pip install rich + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichModelSummary + + trainer = Trainer(callbacks=RichModelSummary()) + + You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar` + + .. code-block:: python + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import RichProgressBar + + trainer = Trainer(callbacks=RichProgressBar()) + + Args: + max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the + layer summary off. + + Raises: + ImportError: + If required `rich` package is not installed on the device. + """ + + def __init__(self, max_depth: int = 1) -> None: + if not _RICH_AVAILABLE: + raise ImportError( + "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`." + ) + super().__init__(max_depth) + + @staticmethod + def summarize( + summary_data: List[Tuple[str, List[str]]], + total_parameters: int, + trainable_parameters: int, + model_size: float, + ) -> None: + + console = Console() + + table = Table(header_style="bold magenta") + table.add_column(" ", style="dim") + table.add_column("Name", justify="left", no_wrap=True) + table.add_column("Type") + table.add_column("Params", justify="right") + + column_names = list(zip(*summary_data))[0] + + for column_name in ["In sizes", "Out sizes"]: + if column_name in column_names: + table.add_column(column_name, justify="right", style="white") + + rows = list(zip(*(arr[1] for arr in summary_data))) + for row in rows: + table.add_row(*row) + + console.print(table) + + parameters = [] + for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]: + parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10)) + + grid = Table.grid(expand=True) + grid.add_column() + grid.add_column() + + grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}") + grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}") + grid.add_row(f"[bold]Total params[/]: {parameters[2]}") + grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}") + + console.print(grid) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 2a2b5d15de..cafad831cb 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -15,7 +15,15 @@ import os from datetime import timedelta from typing import Dict, List, Optional, Union -from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import ( + Callback, + ModelCheckpoint, + ModelSummary, + ProgressBar, + ProgressBarBase, + RichProgressBar, +) +from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -59,8 +67,6 @@ class CallbackConnector: # responsible to stop the training when max_time is reached. self._configure_timer_callback(max_time) - self._configure_model_summary_callback(weights_summary) - # init progress bar if process_position != 0: rank_zero_deprecation( @@ -70,6 +76,9 @@ class CallbackConnector: ) self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) + # configure the ModelSummary callback + self._configure_model_summary_callback(weights_summary) + # push all checkpoint callbacks to the end # it is important that these are the last callbacks to run self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) @@ -102,7 +111,12 @@ class CallbackConnector: f" but got {weights_summary}", ) max_depth = ModelSummaryMode.get_max_depth(weights_summary) - model_summary = ModelSummary(max_depth=max_depth) + if self.trainer._progress_bar_callback is not None and isinstance( + self.trainer._progress_bar_callback, RichProgressBar + ): + model_summary = RichModelSummary(max_depth=max_depth) + else: + model_summary = ModelSummary(max_depth=max_depth) self.trainer.callbacks.append(model_summary) def _configure_swa_callbacks(self): diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py new file mode 100644 index 0000000000..99d557251f --- /dev/null +++ b/tests/callbacks/test_rich_model_summary.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar +from pytorch_lightning.utilities.imports import _RICH_AVAILABLE +from tests.helpers.runif import RunIf + + +@RunIf(rich=True) +def test_rich_model_summary_callback(): + + trainer = Trainer(callbacks=RichProgressBar()) + + assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks) + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + + +def test_rich_progress_bar_import_error(): + + if not _RICH_AVAILABLE: + with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."): + Trainer(callbacks=RichModelSummary())