Add `RichModelSummary` callback (#9546)

This commit is contained in:
Kaushik B 2021-09-17 16:24:16 +05:30 committed by GitHub
parent 11c93d903d
commit 2ec39e275e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 171 additions and 5 deletions

View File

@ -136,6 +136,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_RECONCILE_PROCESS` environment variable to enable process reconciliation regardless of cluster environment settings ([#9389](https://github.com/PyTorchLightning/pytorch-lightning/pull/9389))
- Added `RichModelSummary` callback ([#9546](https://github.com/PyTorchLightning/pytorch-lightning/pull/9546))
### Changed
- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).

View File

@ -106,8 +106,10 @@ Lightning has a few built-in callbacks.
LearningRateMonitor
ModelCheckpoint
ModelPruning
ModelSummary
ProgressBar
ProgressBarBase
RichModelSummary
RichProgressBar
QuantizationAwareTraining
StochasticWeightAveraging

View File

@ -63,6 +63,7 @@ ignore_errors = "True"
module = [
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
"pytorch_lightning.loops.optimization.*",
"pytorch_lightning.loops.evaluation_loop",
"pytorch_lightning.trainer.connectors.checkpoint_connector",

View File

@ -24,6 +24,7 @@ from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
@ -45,7 +46,8 @@ __all__ = [
"ProgressBar",
"ProgressBarBase",
"QuantizationAwareTraining",
"RichModelSummary",
"RichProgressBar",
"StochasticWeightAveraging",
"Timer",
"RichProgressBar",
]

View File

@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.model_summary import get_human_readable_count
if _RICH_AVAILABLE:
from rich.console import Console
from rich.table import Table
class RichModelSummary(ModelSummary):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`
with `rich text formatting <https://github.com/willmcgugan/rich>`_.
Install it with pip:
.. code-block:: bash
pip install rich
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichModelSummary
trainer = Trainer(callbacks=RichModelSummary())
You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
trainer = Trainer(callbacks=RichProgressBar())
Args:
max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
layer summary off.
Raises:
ImportError:
If required `rich` package is not installed on the device.
"""
def __init__(self, max_depth: int = 1) -> None:
if not _RICH_AVAILABLE:
raise ImportError(
"`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`."
)
super().__init__(max_depth)
@staticmethod
def summarize(
summary_data: List[Tuple[str, List[str]]],
total_parameters: int,
trainable_parameters: int,
model_size: float,
) -> None:
console = Console()
table = Table(header_style="bold magenta")
table.add_column(" ", style="dim")
table.add_column("Name", justify="left", no_wrap=True)
table.add_column("Type")
table.add_column("Params", justify="right")
column_names = list(zip(*summary_data))[0]
for column_name in ["In sizes", "Out sizes"]:
if column_name in column_names:
table.add_column(column_name, justify="right", style="white")
rows = list(zip(*(arr[1] for arr in summary_data)))
for row in rows:
table.add_row(*row)
console.print(table)
parameters = []
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
grid = Table.grid(expand=True)
grid.add_column()
grid.add_column()
grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
console.print(grid)

View File

@ -15,7 +15,15 @@ import os
from datetime import timedelta
from typing import Dict, List, Optional, Union
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import (
Callback,
ModelCheckpoint,
ModelSummary,
ProgressBar,
ProgressBarBase,
RichProgressBar,
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -59,8 +67,6 @@ class CallbackConnector:
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)
self._configure_model_summary_callback(weights_summary)
# init progress bar
if process_position != 0:
rank_zero_deprecation(
@ -70,6 +76,9 @@ class CallbackConnector:
)
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
# configure the ModelSummary callback
self._configure_model_summary_callback(weights_summary)
# push all checkpoint callbacks to the end
# it is important that these are the last callbacks to run
self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)
@ -102,7 +111,12 @@ class CallbackConnector:
f" but got {weights_summary}",
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)
model_summary = ModelSummary(max_depth=max_depth)
if self.trainer._progress_bar_callback is not None and isinstance(
self.trainer._progress_bar_callback, RichProgressBar
):
model_summary = RichModelSummary(max_depth=max_depth)
else:
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
def _configure_swa_callbacks(self):

View File

@ -0,0 +1,35 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from tests.helpers.runif import RunIf
@RunIf(rich=True)
def test_rich_model_summary_callback():
trainer = Trainer(callbacks=RichProgressBar())
assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks)
assert isinstance(trainer.progress_bar_callback, RichProgressBar)
def test_rich_progress_bar_import_error():
if not _RICH_AVAILABLE:
with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."):
Trainer(callbacks=RichModelSummary())