Add Profiler table kwargs (#17662)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Bas Krahmer 2023-05-23 17:31:17 +02:00 committed by GitHub
parent 00909ba3ff
commit dea1ff6633
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 3 deletions

View File

@ -21,7 +21,7 @@ from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TY
import torch
from torch import nn, Tensor
from torch.autograd.profiler import record_function
from torch.autograd.profiler import EventList, record_function
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.pytorch.profilers.profiler import Profiler
@ -30,7 +30,6 @@ from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
if TYPE_CHECKING:
from torch.autograd.profiler import EventList
from torch.utils.hooks import RemovableHandle
from lightning.pytorch.core.module import LightningModule
@ -239,6 +238,7 @@ class PyTorchProfiler(Profiler):
row_limit: int = 20,
sort_by_key: Optional[str] = None,
record_module_names: bool = True,
table_kwargs: Optional[Dict[str, Any]] = None,
**profiler_kwargs: Any,
) -> None:
r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
@ -279,6 +279,8 @@ class PyTorchProfiler(Profiler):
record_module_names: Whether to add module names while recording autograd operation.
table_kwargs: Dictionary with keyword arguments for the summary table.
\**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
Raises:
@ -296,6 +298,7 @@ class PyTorchProfiler(Profiler):
self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
self._record_module_names = record_module_names
self._profiler_kwargs = profiler_kwargs
self._table_kwargs = table_kwargs if table_kwargs is not None else {}
self.profiler: Optional[_PROFILER] = None
self.function_events: Optional["EventList"] = None
@ -314,6 +317,19 @@ class PyTorchProfiler(Profiler):
f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
)
for key in self._table_kwargs:
if key in {"sort_by", "row_limit"}:
raise KeyError(
f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler."
)
valid_table_keys = set(inspect.signature(EventList.table).parameters.keys()) - {
"self",
"sort_by",
"row_limit",
}
if key not in valid_table_keys:
raise KeyError(f"Found invalid table_kwargs key: {key}. Should be within {valid_table_keys}.")
def _init_kineto(self, profiler_kwargs: Any) -> None:
has_schedule = "schedule" in profiler_kwargs
self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
@ -485,7 +501,7 @@ class PyTorchProfiler(Profiler):
self.function_events.export_chrome_trace(path_to_trace)
data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes)
table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit)
table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit, **self._table_kwargs)
recorded_stats = {"records": table}
return self._stats_to_str(recorded_stats)

View File

@ -619,3 +619,38 @@ def test_profile_callbacks(tmpdir):
e.name == "[pl][profile][Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
for e in pytorch_profiler.function_events
)
@RunIf(min_python="3.10")
def test_profiler_table_kwargs_summary_length(tmpdir):
"""Test if setting max_name_column_width in table_kwargs changes table width."""
summaries = []
# Default table_kwargs (None) sets max_name_column_width to 55
for table_kwargs in [{"max_name_column_width": 1}, {"max_name_column_width": 5}, None]:
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None, table_kwargs=table_kwargs)
with pytorch_profiler.profile("a"):
torch.ones(1)
pytorch_profiler.describe()
summaries.append(pytorch_profiler.summary())
# Check if setting max_name_column_width results in a wider table (more dashes)
assert summaries[0].count("-") < summaries[1].count("-")
assert summaries[1].count("-") < summaries[2].count("-")
def test_profiler_invalid_table_kwargs(tmpdir):
"""Test if passing invalid keyword arguments raise expected error."""
for key in {"row_limit", "sort_by"}:
with pytest.raises(
KeyError,
match=f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler.",
):
PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile")
for key in {"self", "non_existent_keyword_arg"}:
with pytest.raises(KeyError) as exc_info:
PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile")
assert exc_info.value.args[0].startswith(f"Found invalid table_kwargs key: {key}.")