feat: Add ModelSummary Callback (#9344)

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Kaushik B 2021-09-10 18:12:42 +05:30 committed by GitHub
parent 4f8c3ba4a5
commit d773407e59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 269 additions and 33 deletions

View File

@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))
- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))
### Changed
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))

View File

@@ -19,6 +19,7 @@ from pytorch_lightning.callbacks.gradient_accumulation_scheduler import Gradient
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
@@ -39,6 +40,7 @@ __all__ = [
"LearningRateMonitor",
"ModelCheckpoint",
"ModelPruning",
"ModelSummary",
"BasePredictionWriter",
"ProgressBar",
"ProgressBarBase",

View File

@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model Summary
=============
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.
"""
import logging
from typing import List, Optional, Union
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize
log = logging.getLogger(__name__)
class ModelSummary(Callback):
r"""
Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.
Args:
max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
layer summary off.
Example::
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelSummary
>>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
"""
def __init__(self, max_depth: Optional[int] = 1):
self._max_depth: int = max_depth
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.is_global_zero and self._max_depth is not None and not trainer.testing:
model_summary = summarize(pl_module, max_depth=self._max_depth)
summary_data = model_summary._get_summary_data()
total_parameters = model_summary.total_parameters
trainable_parameters = model_summary.trainable_parameters
model_size = model_summary.model_size
self.summarize(summary_data, total_parameters, trainable_parameters, model_size)
@staticmethod
def summarize(
summary_data: List[List[Union[str, List[str]]]],
total_parameters: int,
trainable_parameters: int,
model_size: float,
) -> None:
summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)
log.info("\n" + summary_table)
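As an aside, the staticmethod `summarize` hook above makes the output destination easy to customize. The subclass below is an illustrative sketch (not part of this commit) that prints the table instead of sending it to the logger, reusing `_format_summary_table` exactly as the callback does; the class name is hypothetical.

from typing import List, Union

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities.model_summary import _format_summary_table


class PrintModelSummary(ModelSummary):
    # Hypothetical subclass for illustration: print the summary table instead of logging it.
    @staticmethod
    def summarize(
        summary_data: List[List[Union[str, List[str]]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        print(_format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data))


trainer = Trainer(callbacks=[PrintModelSummary(max_depth=1)])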

View File

@@ -15,9 +15,9 @@ import os
from datetime import timedelta
from typing import Dict, List, Optional, Union
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
@@ -34,6 +34,7 @@ class CallbackConnector:
process_position: int,
default_root_dir: Optional[str],
weights_save_path: Optional[str],
weights_summary: Optional[str],
stochastic_weight_avg: bool,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
):
@@ -58,6 +59,8 @@
# responsible to stop the training when max_time is reached.
self._configure_timer_callback(max_time)
self._configure_model_summary_callback(weights_summary)
# init progress bar
if process_position != 0:
rank_zero_deprecation(
@@ -89,6 +92,19 @@
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint())
def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
return
if weights_summary is not None:
if weights_summary not in ModelSummaryMode.supported_types():
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
f" but got {weights_summary}",
)
max_depth = ModelSummaryMode.get_max_depth(weights_summary)
model_summary = ModelSummary(max_depth=max_depth)
self.trainer.callbacks.append(model_summary)
def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
return
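A minimal sketch of the resulting connector behaviour, mirroring the tests added further down: the legacy `weights_summary` string is mapped to a `max_depth`, and a user-provided `ModelSummary` callback takes precedence over the flag.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

# "full" -> ModelSummary(max_depth=-1), "top" -> ModelSummary(max_depth=1)
trainer = Trainer(weights_summary="full")
summary_cb = next(cb for cb in trainer.callbacks if isinstance(cb, ModelSummary))
assert summary_cb._max_depth == -1

# an explicitly passed callback wins; the connector does not append a second one
trainer = Trainer(callbacks=[ModelSummary(max_depth=2)], weights_summary="top")
assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1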

View File

@@ -79,7 +79,6 @@ from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -407,11 +406,6 @@ class Trainer(
# default .predict() loop
self.predict_loop = PredictionLoop()
# training state
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
)
self.weights_summary = weights_summary
# init callbacks
@@ -423,6 +417,7 @@ class Trainer(
process_position,
default_root_dir,
weights_save_path,
self.weights_summary,
stochastic_weight_avg,
max_time,
)
@@ -1108,11 +1103,6 @@ class Trainer(
# --------------------------
self.call_hook("on_pretrain_routine_start")
# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
max_depth = ModelSummary.MODES[self.weights_summary]
summarize(self.lightning_module, max_depth=max_depth)
self.call_hook("on_pretrain_routine_end")
def _run_train(self) -> None:

View File

@@ -23,6 +23,7 @@ from pytorch_lightning.utilities.enums import ( # noqa: F401
DistributedType,
GradClipAlgorithmType,
LightningEnum,
ModelSummaryMode,
)
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401

View File

@@ -73,7 +73,7 @@ class PrecisionType(LightningEnum):
class DistributedType(LightningEnum):
"""Define type of ditributed computing.
"""Define type of distributed computing.
>>> # you can match the type with string
>>> DistributedType.DDP == 'ddp'
@@ -147,3 +147,35 @@ class AutoRestartBatchKeys(LightningEnum):
"""Defines special dictionary keys used to track captured dataset state with multiple workers."""
PL_RESTART_META = "__pl_restart_meta"
class ModelSummaryMode(LightningEnum):
# TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`)
"""Define the Model Summary mode to be used.
Can be one of
- `top`: only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module
>>> # you can match the type with string
>>> ModelSummaryMode.TOP == 'TOP'
True
>>> # which is case invariant
>>> ModelSummaryMode.TOP in ('top', 'FULL')
True
"""
TOP = "top"
FULL = "full"
@staticmethod
def get_max_depth(mode: str) -> int:
if mode == ModelSummaryMode.TOP:
return 1
if mode == ModelSummaryMode.FULL:
return -1
raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")
@staticmethod
def supported_types() -> List[str]:
return [x.value for x in ModelSummaryMode]
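In short, the enum centralizes the mapping from the legacy string modes to `max_depth` values; the same assertions appear in the new `test_model_summary_mode` test at the end of this commit.

from pytorch_lightning.utilities import ModelSummaryMode

assert ModelSummaryMode.supported_types() == ["top", "full"]
assert ModelSummaryMode.get_max_depth("top") == 1
assert ModelSummaryMode.get_max_depth("full") == -1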

View File

@@ -23,7 +23,7 @@ from torch import Tensor
from torch.utils.hooks import RemovableHandle
import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.warnings import WarningCache
@@ -185,21 +185,21 @@ class ModelSummary:
0.530 Total estimated model params size (MB)
"""
MODES = dict(top=1, full=-1) # TODO: remove in v1.6
def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
self._model = model
# temporary mapping from mode to max_depth
if max_depth is None or mode is not None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
if mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `ModelSummary` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
)
else:
raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
raise MisconfigurationException(
f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}."
)
if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
@@ -295,7 +295,7 @@ class ModelSummary:
model(input_)
model.train(mode) # restore mode of module
def __str__(self):
def _get_summary_data(self):
"""Makes a summary listing with:
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
@@ -310,6 +310,11 @@ class ModelSummary:
arrays.append(["In sizes", self.in_sizes])
arrays.append(["Out sizes", self.out_sizes])
return arrays
def __str__(self):
arrays = self._get_summary_data()
total_parameters = self.total_parameters
trainable_parameters = self.trainable_parameters
model_size = self.model_size
@@ -445,16 +450,17 @@ def summarize(
# temporary mapping from mode to max_depth
if max_depth is None:
if mode in ModelSummary.MODES:
max_depth = ModelSummary.MODES[mode]
if mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
)
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
elif mode is not None:
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
raise MisconfigurationException(
f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}"
)
else:
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
log.info("\n" + str(model_summary))
return model_summary
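To illustrate the split: `__str__` still renders the familiar table, while `_get_summary_data` exposes the raw columns that the new callback consumes. A small sketch using the test suite's `BoringModel`, with the column layout inferred from the assertions in the new callback test:

from pytorch_lightning.utilities.model_summary import summarize
from tests.helpers.boring_model import BoringModel

summary = summarize(BoringModel(), max_depth=1)

print(summary)  # the usual formatted table

data = summary._get_summary_data()
# data[1] == ["Name", [...]], data[2] == ["Type", [...]], data[3] == ["Params", [...]]
print(summary.total_parameters, summary.trainable_parameters, summary.model_size)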

View File

@@ -0,0 +1,86 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
def test_model_summary_callback_present_trainer():
trainer = Trainer()
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
trainer = Trainer(callbacks=ModelSummary())
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
def test_model_summary_callback_with_weights_summary_none():
trainer = Trainer(weights_summary=None)
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
def test_model_summary_callback_with_weights_summary():
trainer = Trainer(weights_summary="top")
model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
assert model_summary_callback._max_depth == 1
trainer = Trainer(weights_summary="full")
model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
assert model_summary_callback._max_depth == -1
with pytest.raises(
MisconfigurationException, match=f"`weights_summary` can be None, {', '.join(list(ModelSummaryMode))}"
):
_ = Trainer(weights_summary="invalid")
def test_model_summary_callback_override_weights_summary_flag():
trainer = Trainer(callbacks=ModelSummary(), weights_summary=None)
assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
def test_custom_model_summary_callback_summarize(tmpdir):
class CustomModelSummary(ModelSummary):
@staticmethod
def summarize(
summary_data: List[List[Union[str, List[str]]]],
total_parameters: int,
trainable_parameters: int,
model_size: float,
) -> None:
assert summary_data[1][0] == "Name"
assert summary_data[1][1][0] == "layer"
assert summary_data[2][0] == "Type"
assert summary_data[2][1][0] == "Linear"
assert summary_data[3][0] == "Params"
assert total_parameters == 66
assert trainable_parameters == 66
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, callbacks=CustomModelSummary(), max_steps=1)
trainer.fit(model)

View File

@@ -21,6 +21,7 @@ from pytorch_lightning.callbacks import (
GradientAccumulationScheduler,
LearningRateMonitor,
ModelCheckpoint,
ModelSummary,
ProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -31,31 +32,32 @@ def test_checkpoint_callbacks_are_last(tmpdir):
"""Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
checkpoint1 = ModelCheckpoint(tmpdir)
checkpoint2 = ModelCheckpoint(tmpdir)
model_summary = ModelSummary()
early_stopping = EarlyStopping()
lr_monitor = LearningRateMonitor()
progress_bar = ProgressBar()
# no model reference
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]
# no model callbacks
model = LightningModule()
model.configure_callbacks = lambda: []
trainer.model = model
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]
# with model-specific callbacks that substitute ones in Trainer
model = LightningModule()
model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2]
assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, model_summary, checkpoint1, checkpoint2]
class StatefulCallback0(Callback):
@@ -119,7 +121,9 @@ def test_attach_model_callbacks():
def assert_composition(trainer_callbacks, model_callbacks, expected):
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
trainer = Trainer(
checkpoint_callback=False, progress_bar_refresh_rate=0, weights_summary=None, callbacks=trainer_callbacks
)
trainer.model = model
cb_connector = CallbackConnector(trainer)
cb_connector._attach_model_callbacks()

View File

@@ -1,5 +1,19 @@
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.enums import PrecisionType
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning.utilities.enums import DeviceType, ModelSummaryMode, PrecisionType
def test_consistency():
@@ -18,3 +32,13 @@ def test_precision_supported_types():
assert PrecisionType.supported_type("16")
assert not PrecisionType.supported_type(1)
assert not PrecisionType.supported_type("invalid")
def test_model_summary_mode():
assert ModelSummaryMode.supported_types() == ["top", "full"]
assert ModelSummaryMode.TOP in ("top", "full")
assert ModelSummaryMode.get_max_depth("top") == 1
assert ModelSummaryMode.get_max_depth("full") == -1
with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."):
ModelSummaryMode.get_max_depth("invalid")