diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index 1971646eeb..fe3ab1c189 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -21,7 +21,7 @@ from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TY import torch from torch import nn, Tensor -from torch.autograd.profiler import record_function +from torch.autograd.profiler import EventList, record_function from lightning.fabric.accelerators.cuda import is_cuda_available from lightning.pytorch.profilers.profiler import Profiler @@ -30,7 +30,6 @@ from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache if TYPE_CHECKING: - from torch.autograd.profiler import EventList from torch.utils.hooks import RemovableHandle from lightning.pytorch.core.module import LightningModule @@ -239,6 +238,7 @@ class PyTorchProfiler(Profiler): row_limit: int = 20, sort_by_key: Optional[str] = None, record_module_names: bool = True, + table_kwargs: Optional[Dict[str, Any]] = None, **profiler_kwargs: Any, ) -> None: r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. @@ -279,6 +279,8 @@ class PyTorchProfiler(Profiler): record_module_names: Whether to add module names while recording autograd operation. + table_kwargs: Dictionary with keyword arguments for the summary table. + \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version Raises: @@ -296,6 +298,7 @@ class PyTorchProfiler(Profiler): self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" self._record_module_names = record_module_names self._profiler_kwargs = profiler_kwargs + self._table_kwargs = table_kwargs if table_kwargs is not None else {} self.profiler: Optional[_PROFILER] = None self.function_events: Optional["EventList"] = None @@ -314,6 +317,19 @@ class PyTorchProfiler(Profiler): f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " ) + for key in self._table_kwargs: + if key in {"sort_by", "row_limit"}: + raise KeyError( + f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler." + ) + valid_table_keys = set(inspect.signature(EventList.table).parameters.keys()) - { + "self", + "sort_by", + "row_limit", + } + if key not in valid_table_keys: + raise KeyError(f"Found invalid table_kwargs key: {key}. Should be within {valid_table_keys}.") + def _init_kineto(self, profiler_kwargs: Any) -> None: has_schedule = "schedule" in profiler_kwargs self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs @@ -485,7 +501,7 @@ class PyTorchProfiler(Profiler): self.function_events.export_chrome_trace(path_to_trace) data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) - table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit, **self._table_kwargs) recorded_stats = {"records": table} return self._stats_to_str(recorded_stats) diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index e5c9098e3f..01ee07770e 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -619,3 +619,38 @@ def test_profile_callbacks(tmpdir): e.name == "[pl][profile][Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start" for e in pytorch_profiler.function_events ) + + +@RunIf(min_python="3.10") +def test_profiler_table_kwargs_summary_length(tmpdir): + """Test if setting max_name_column_width in table_kwargs changes table width.""" + + summaries = [] + # Default table_kwargs (None) sets max_name_column_width to 55 + for table_kwargs in [{"max_name_column_width": 1}, {"max_name_column_width": 5}, None]: + pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None, table_kwargs=table_kwargs) + + with pytorch_profiler.profile("a"): + torch.ones(1) + pytorch_profiler.describe() + summaries.append(pytorch_profiler.summary()) + + # Check if setting max_name_column_width results in a wider table (more dashes) + assert summaries[0].count("-") < summaries[1].count("-") + assert summaries[1].count("-") < summaries[2].count("-") + + +def test_profiler_invalid_table_kwargs(tmpdir): + """Test if passing invalid keyword arguments raise expected error.""" + + for key in {"row_limit", "sort_by"}: + with pytest.raises( + KeyError, + match=f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler.", + ): + PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile") + + for key in {"self", "non_existent_keyword_arg"}: + with pytest.raises(KeyError) as exc_info: + PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile") + assert exc_info.value.args[0].startswith(f"Found invalid table_kwargs key: {key}.")