110 lines
3.7 KiB
Python
110 lines
3.7 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import List, Tuple
|
|
|
|
from pytorch_lightning.callbacks import ModelSummary
|
|
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
|
|
from pytorch_lightning.utilities.model_summary import get_human_readable_count
|
|
|
|
if _RICH_AVAILABLE:
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
|
|
class RichModelSummary(ModelSummary):
    r"""
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`
    with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichModelSummary

        trainer = Trainer(callbacks=RichModelSummary())

    You can also enable ``RichModelSummary`` by using the :class:`~pytorch_lightning.callbacks.RichProgressBar`:

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off.

    Raises:
        ImportError:
            If the required `rich` package is not installed.
    """

    def __init__(self, max_depth: int = 1) -> None:
        # Fail fast at construction time rather than when the summary is first printed.
        if not _RICH_AVAILABLE:
            raise ImportError(
                "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`."
            )
        super().__init__(max_depth)

    @staticmethod
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        """Print the per-layer table followed by the parameter and size totals using ``rich``.

        Args:
            summary_data: One ``(column_name, column_values)`` pair per table column, in display order.
                Rows are built by zipping the columns together, so all columns should have equal length.
            total_parameters: Total number of parameters of the model.
            trainable_parameters: Number of trainable parameters of the model.
            model_size: Estimated size of the model parameters, in megabytes.
        """
        console = Console()

        table = Table(header_style="bold magenta")
        table.add_column(" ", style="dim")
        table.add_column("Name", justify="left", no_wrap=True)
        table.add_column("Type")
        table.add_column("Params", justify="right")

        # The size columns are optional: add their headers only if the summary data provides them.
        # Collecting the names directly (instead of ``list(zip(*summary_data))[0]``) also copes with
        # an empty ``summary_data`` instead of raising ``IndexError``.
        column_names = {name for name, _ in summary_data}
        for column_name in ["In sizes", "Out sizes"]:
            if column_name in column_names:
                table.add_column(column_name, justify="right", style="white")

        # Transpose the column-oriented data into one row per layer.
        for row in zip(*(values for _, values in summary_data)):
            table.add_row(*row)

        console.print(table)

        counts = [trainable_parameters, total_parameters - trainable_parameters, total_parameters]
        parameters = ["{:<{}}".format(get_human_readable_count(int(count)), 10) for count in counts]
        # ``model_size`` is a float in MB, not a parameter count: format it with three decimals rather
        # than truncating it with ``int`` (which would print "0" for any model smaller than 1 MB) and
        # running it through the count formatter.
        parameters.append("{:<{}}".format(f"{model_size:,.3f}", 10))

        grid = Table.grid(expand=True)
        grid.add_column()
        grid.add_column()

        grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
        grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
        grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
        grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")

        console.print(grid)
|