lightning/tests/callbacks/test_rich_model_summary.py

66 lines
2.5 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from unittest import mock
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar
from pytorch_lightning.utilities.model_summary import summarize
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@RunIf(rich=True)
def test_rich_model_summary_callback():
    """A Trainer given a ``RichProgressBar`` also ends up with a ``RichModelSummary`` callback."""
    trainer = Trainer(callbacks=RichProgressBar())

    # Collect every summary callback of the rich flavour; at least one must be present.
    rich_summaries = [cb for cb in trainer.callbacks if isinstance(cb, RichModelSummary)]
    assert rich_summaries

    # The progress bar handed to the Trainer must be the one it exposes.
    assert isinstance(trainer.progress_bar_callback, RichProgressBar)
def test_rich_progress_bar_import_error(monkeypatch):
    """Constructing ``RichModelSummary`` fails when ``rich`` is flagged as unavailable.

    Despite the "progress_bar" in its name, this test exercises ``RichModelSummary``:
    the module-level ``_RICH_AVAILABLE`` flag is patched to ``False`` and the
    constructor is expected to raise ``ModuleNotFoundError``.
    """
    import pytorch_lightning.callbacks.rich_model_summary as summary_module

    monkeypatch.setattr(summary_module, "_RICH_AVAILABLE", False)

    expected_message = "`RichModelSummary` requires `rich` to be installed."
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        RichModelSummary()
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Console.print", autospec=True)
@mock.patch("pytorch_lightning.callbacks.rich_model_summary.Table.add_row", autospec=True)
def test_rich_summary_tuples(mock_table_add_row, mock_console):
    """Check that summary rows reach the table as strings and that printing happens twice.

    ``Table.add_row`` and ``Console.print`` are both patched with ``autospec=True``,
    so each recorded call carries the bound instance as its first positional argument.
    """
    model_summary = RichModelSummary()

    class TestModel(BoringModel):
        """BoringModel variant that exposes an example input so shapes are summarized."""

        @property
        def example_input_array(self) -> Any:
            return torch.randn(4, 32)

    summary_data = summarize(TestModel())._get_summary_data()
    model_summary.summarize(
        summary_data=summary_data,
        total_parameters=1,
        trainable_parameters=1,
        model_size=1,
    )

    # One print for the summary table, one for the parameter breakdown.
    assert mock_console.call_count == 2

    # Inspect the positional arguments of the first row added to the table;
    # index 0 is the autospec'd ``self``, so it is skipped.
    first_row_args = mock_table_add_row.call_args_list[0][0]
    expected_row = ("0", "layer", "Linear", "66 ", "[4, 32]", "[4, 2]")
    assert first_row_args[1:] == expected_row