feat: Add ModelSummary Callback (#9344)
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 4f8c3ba4a5
commit d773407e59
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `remove_checkpoint` to `CheckpointIO` plugin by moving the responsibility from `ModelCheckpoint` Callback ([#9373](https://github.com/PyTorchLightning/pytorch-lightning/pull/9373))

- Added `ModelSummary` callback ([#9344](https://github.com/PyTorchLightning/pytorch-lightning/pull/9344))

### Changed

- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))
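A minimal sketch of the new `gpus` parsing described in the entry above, assuming a machine with at least three visible GPUs:

from pytorch_lightning import Trainer

# Before this change, gpus="3" selected the single device with index 3.
# Now the string form is parsed like the int form and selects the first n devices.
trainer = Trainer(gpus="3")  # devices 0, 1, 2
trainer = Trainer(gpus=3)    # equivalent int form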
@@ -19,6 +19,7 @@ from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
from pytorch_lightning.callbacks.lambda_function import LambdaCallback
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
@@ -39,6 +40,7 @@ __all__ = [
    "LearningRateMonitor",
    "ModelCheckpoint",
    "ModelPruning",
    "ModelSummary",
    "BasePredictionWriter",
    "ProgressBar",
    "ProgressBarBase",
@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model Summary
=============

Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

The string representation of this summary prints a table with columns containing
the name, type and number of parameters for each layer.

"""
import logging
from typing import List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.model_summary import _format_summary_table, summarize

log = logging.getLogger(__name__)


class ModelSummary(Callback):
    r"""
    Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`.

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelSummary
        >>> trainer = Trainer(callbacks=[ModelSummary(max_depth=1)])
    """

    def __init__(self, max_depth: Optional[int] = 1):
        self._max_depth: int = max_depth

    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.is_global_zero and self._max_depth is not None and not trainer.testing:
            model_summary = summarize(pl_module, max_depth=self._max_depth)

            summary_data = model_summary._get_summary_data()
            total_parameters = model_summary.total_parameters
            trainable_parameters = model_summary.trainable_parameters
            model_size = model_summary.model_size

            self.summarize(summary_data, total_parameters, trainable_parameters, model_size)

    @staticmethod
    def summarize(
        summary_data: List[List[Union[str, List[str]]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        summary_table = _format_summary_table(total_parameters, trainable_parameters, model_size, *summary_data)

        log.info("\n" + summary_table)
@@ -15,9 +15,9 @@ import os
from datetime import timedelta
from typing import Dict, List, Optional, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ModelSummary, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
@@ -34,6 +34,7 @@ class CallbackConnector:
        process_position: int,
        default_root_dir: Optional[str],
        weights_save_path: Optional[str],
        weights_summary: Optional[str],
        stochastic_weight_avg: bool,
        max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
    ):
@@ -58,6 +59,8 @@ class CallbackConnector:
        # responsible to stop the training when max_time is reached.
        self._configure_timer_callback(max_time)

        self._configure_model_summary_callback(weights_summary)

        # init progress bar
        if process_position != 0:
            rank_zero_deprecation(
@@ -89,6 +92,19 @@ class CallbackConnector:
        if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
            self.trainer.callbacks.append(ModelCheckpoint())

    def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:
        if any(isinstance(cb, ModelSummary) for cb in self.trainer.callbacks):
            return
        if weights_summary is not None:
            if weights_summary not in ModelSummaryMode.supported_types():
                raise MisconfigurationException(
                    f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
                    f" but got {weights_summary}",
                )
            max_depth = ModelSummaryMode.get_max_depth(weights_summary)
            model_summary = ModelSummary(max_depth=max_depth)
            self.trainer.callbacks.append(model_summary)

    def _configure_swa_callbacks(self):
        if not self.trainer._stochastic_weight_avg:
            return
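Given the mapping used by `_configure_model_summary_callback` above ("top" maps to `max_depth=1`, "full" to `max_depth=-1`), the two Trainer configurations below should be roughly equivalent; a minimal sketch, not an exhaustive statement of the connector's behaviour:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary

# The connector appends a ModelSummary callback for a non-None weights_summary...
trainer_a = Trainer(weights_summary="full")

# ...which mirrors passing the callback explicitly with the mapped max_depth.
trainer_b = Trainer(callbacks=[ModelSummary(max_depth=-1)], weights_summary=None)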
@@ -79,7 +79,6 @@ from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary, summarize
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -407,11 +406,6 @@ class Trainer(
        # default .predict() loop
        self.predict_loop = PredictionLoop()

        # training state
        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
            )
        self.weights_summary = weights_summary

        # init callbacks
@@ -423,6 +417,7 @@ class Trainer(
            process_position,
            default_root_dir,
            weights_save_path,
            self.weights_summary,
            stochastic_weight_avg,
            max_time,
        )
@@ -1108,11 +1103,6 @@ class Trainer(
        # --------------------------
        self.call_hook("on_pretrain_routine_start")

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            max_depth = ModelSummary.MODES[self.weights_summary]
            summarize(self.lightning_module, max_depth=max_depth)

        self.call_hook("on_pretrain_routine_end")

    def _run_train(self) -> None:
@@ -23,6 +23,7 @@ from pytorch_lightning.utilities.enums import (  # noqa: F401
    DistributedType,
    GradClipAlgorithmType,
    LightningEnum,
    ModelSummaryMode,
)
from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
from pytorch_lightning.utilities.imports import (  # noqa: F401
@@ -73,7 +73,7 @@ class PrecisionType(LightningEnum):


class DistributedType(LightningEnum):
    """Define type of ditributed computing.
    """Define type of distributed computing.

    >>> # you can match the type with string
    >>> DistributedType.DDP == 'ddp'
@@ -147,3 +147,35 @@ class AutoRestartBatchKeys(LightningEnum):
    """Defines special dictionary keys used to track captured dataset state with multiple workers."""

    PL_RESTART_META = "__pl_restart_meta"


class ModelSummaryMode(LightningEnum):
    # TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`)
    """Define the Model Summary mode to be used.

    Can be one of
        - `top`: only the top-level modules will be recorded (the children of the root module)
        - `full`: summarizes all layers and their submodules in the root module

    >>> # you can match the type with string
    >>> ModelSummaryMode.TOP == 'TOP'
    True
    >>> # which is case invariant
    >>> ModelSummaryMode.TOP in ('top', 'FULL')
    True
    """

    TOP = "top"
    FULL = "full"

    @staticmethod
    def get_max_depth(mode: str) -> int:
        if mode == ModelSummaryMode.TOP:
            return 1
        if mode == ModelSummaryMode.FULL:
            return -1
        raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.")

    @staticmethod
    def supported_types() -> List[str]:
        return [x.value for x in ModelSummaryMode]
@@ -23,7 +23,7 @@ from torch import Tensor
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.warnings import WarningCache
@@ -185,21 +185,21 @@ class ModelSummary:
    0.530     Total estimated model params size (MB)
    """

    MODES = dict(top=1, full=-1)  # TODO: remove in v1.6

    def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] = 1):
        self._model = model

        # temporary mapping from mode to max_depth
        if max_depth is None or mode is not None:
            if mode in ModelSummary.MODES:
                max_depth = ModelSummary.MODES[mode]
            if mode in ModelSummaryMode.supported_types():
                max_depth = ModelSummaryMode.get_max_depth(mode)
                rank_zero_deprecation(
                    "Argument `mode` in `ModelSummary` is deprecated in v1.4"
                    f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
                )
            else:
                raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
                raise MisconfigurationException(
                    f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}."
                )

        if not isinstance(max_depth, int) or max_depth < -1:
            raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
@@ -295,7 +295,7 @@ class ModelSummary:
        model(input_)
        model.train(mode)  # restore mode of module

    def __str__(self):
    def _get_summary_data(self):
        """Makes a summary listing with:

        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
@@ -310,6 +310,11 @@ class ModelSummary:
        arrays.append(["In sizes", self.in_sizes])
        arrays.append(["Out sizes", self.out_sizes])

        return arrays

    def __str__(self):
        arrays = self._get_summary_data()

        total_parameters = self.total_parameters
        trainable_parameters = self.trainable_parameters
        model_size = self.model_size
@@ -445,16 +450,17 @@ def summarize(

    # temporary mapping from mode to max_depth
    if max_depth is None:
        if mode in ModelSummary.MODES:
            max_depth = ModelSummary.MODES[mode]
        if mode in ModelSummaryMode.supported_types():
            max_depth = ModelSummaryMode.get_max_depth(mode)
            rank_zero_deprecation(
                "Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
                f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
            )
            model_summary = ModelSummary(lightning_module, max_depth=max_depth)
        elif mode is not None:
            raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
            raise MisconfigurationException(
                f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}"
            )
    else:
        model_summary = ModelSummary(lightning_module, max_depth=max_depth)
    log.info("\n" + str(model_summary))
    return model_summary
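A minimal usage sketch of the deprecation path handled above, assuming `model` is a plain `LightningModule` instance:

from pytorch_lightning.utilities.model_summary import summarize

# New-style call: the depth is passed explicitly (-1 summarizes all nested submodules).
summary = summarize(model, max_depth=-1)

# Deprecated call: `mode` is translated to a max_depth ("full" -> -1) and the
# deprecation warning shown above is emitted.
summary = summarize(model, mode="full")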
@@ -0,0 +1,86 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union

import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelSummary
from pytorch_lightning.utilities import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


def test_model_summary_callback_present_trainer():

    trainer = Trainer()
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

    trainer = Trainer(callbacks=ModelSummary())
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_model_summary_callback_with_weights_summary_none():

    trainer = Trainer(weights_summary=None)
    assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_model_summary_callback_with_weights_summary():

    trainer = Trainer(weights_summary="top")

    model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
    assert model_summary_callback._max_depth == 1

    trainer = Trainer(weights_summary="full")

    model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
    assert model_summary_callback._max_depth == -1

    with pytest.raises(
        MisconfigurationException, match=f"`weights_summary` can be None, {', '.join(list(ModelSummaryMode))}"
    ):
        _ = Trainer(weights_summary="invalid")


def test_model_summary_callback_override_weights_summary_flag():

    trainer = Trainer(callbacks=ModelSummary(), weights_summary=None)
    assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)


def test_custom_model_summary_callback_summarize(tmpdir):
    class CustomModelSummary(ModelSummary):
        @staticmethod
        def summarize(
            summary_data: List[List[Union[str, List[str]]]],
            total_parameters: int,
            trainable_parameters: int,
            model_size: float,
        ) -> None:
            assert summary_data[1][0] == "Name"
            assert summary_data[1][1][0] == "layer"

            assert summary_data[2][0] == "Type"
            assert summary_data[2][1][0] == "Linear"

            assert summary_data[3][0] == "Params"
            assert total_parameters == 66
            assert trainable_parameters == 66

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, callbacks=CustomModelSummary(), max_steps=1)

    trainer.fit(model)
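The assertions in the test above imply that `summary_data` is a list of `[header, column_values]` pairs. A hedged sketch of the shape for BoringModel's single Linear layer; the leading index column and the exact rendering of the values are inferred, not asserted by the test:

# Roughly, for a model whose only child is layer = torch.nn.Linear(32, 2):
summary_data = [
    [" ", ["0"]],          # row index column (assumed)
    ["Name", ["layer"]],   # layer names
    ["Type", ["Linear"]],  # layer types
    ["Params", ["66"]],    # parameter counts; 32 * 2 weights + 2 biases = 66
]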
@@ -21,6 +21,7 @@ from pytorch_lightning.callbacks import (
    GradientAccumulationScheduler,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    ProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -31,31 +32,32 @@ def test_checkpoint_callbacks_are_last(tmpdir):
    """Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
    checkpoint1 = ModelCheckpoint(tmpdir)
    checkpoint2 = ModelCheckpoint(tmpdir)
    model_summary = ModelSummary()
    early_stopping = EarlyStopping()
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    # no model reference
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
    assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]

    # no model callbacks
    model = LightningModule()
    model.configure_callbacks = lambda: []
    trainer.model = model
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]
    assert trainer.callbacks == [progress_bar, lr_monitor, model_summary, checkpoint1, checkpoint2]

    # with model-specific callbacks that substitute ones in Trainer
    model = LightningModule()
    model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
    model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, checkpoint1, checkpoint2]
    assert trainer.callbacks == [progress_bar, lr_monitor, early_stopping, model_summary, checkpoint1, checkpoint2]


class StatefulCallback0(Callback):
@@ -119,7 +121,9 @@ def test_attach_model_callbacks():
    def assert_composition(trainer_callbacks, model_callbacks, expected):
        model = LightningModule()
        model.configure_callbacks = lambda: model_callbacks
        trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
        trainer = Trainer(
            checkpoint_callback=False, progress_bar_refresh_rate=0, weights_summary=None, callbacks=trainer_callbacks
        )
        trainer.model = model
        cb_connector = CallbackConnector(trainer)
        cb_connector._attach_model_callbacks()
@@ -1,5 +1,19 @@
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.enums import PrecisionType
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning.utilities.enums import DeviceType, ModelSummaryMode, PrecisionType


def test_consistency():
@@ -18,3 +32,13 @@ def test_precision_supported_types():
    assert PrecisionType.supported_type("16")
    assert not PrecisionType.supported_type(1)
    assert not PrecisionType.supported_type("invalid")


def test_model_summary_mode():
    assert ModelSummaryMode.supported_types() == ["top", "full"]
    assert ModelSummaryMode.TOP in ("top", "full")
    assert ModelSummaryMode.get_max_depth("top") == 1
    assert ModelSummaryMode.get_max_depth("full") == -1

    with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."):
        ModelSummaryMode.get_max_depth("invalid")