Add Profiler table kwargs (#17662)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 00909ba3ff
commit dea1ff6633
@@ -21,7 +21,7 @@ from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TY
 import torch
 from torch import nn, Tensor
-from torch.autograd.profiler import record_function
+from torch.autograd.profiler import EventList, record_function
 
 from lightning.fabric.accelerators.cuda import is_cuda_available
 from lightning.pytorch.profilers.profiler import Profiler
@@ -30,7 +30,6 @@ from lightning.pytorch.utilities.imports import _KINETO_AVAILABLE
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn, WarningCache
 
 if TYPE_CHECKING:
-    from torch.autograd.profiler import EventList
     from torch.utils.hooks import RemovableHandle
 
     from lightning.pytorch.core.module import LightningModule
@@ -239,6 +238,7 @@ class PyTorchProfiler(Profiler):
         row_limit: int = 20,
         sort_by_key: Optional[str] = None,
         record_module_names: bool = True,
+        table_kwargs: Optional[Dict[str, Any]] = None,
         **profiler_kwargs: Any,
     ) -> None:
         r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
@@ -279,6 +279,8 @@ class PyTorchProfiler(Profiler):
 
             record_module_names: Whether to add module names while recording autograd operation.
 
+            table_kwargs: Dictionary with keyword arguments for the summary table.
+
             \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
 
         Raises:
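For context, a minimal usage sketch of the new argument (the key and value below are illustrative, not part of this commit; legal keys are whatever your torch version's EventList.table accepts, minus sort_by and row_limit):

    from lightning.pytorch.profilers import PyTorchProfiler

    # table_kwargs is stored on the profiler and forwarded verbatim to
    # EventList.table() when the summary is rendered (see the summary() hunk below).
    # max_name_column_width is assumed to exist in your torch version.
    profiler = PyTorchProfiler(
        dirpath=".",
        filename="perf_logs",
        table_kwargs={"max_name_column_width": 100},
    )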
@@ -296,6 +298,7 @@ class PyTorchProfiler(Profiler):
         self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
         self._record_module_names = record_module_names
         self._profiler_kwargs = profiler_kwargs
+        self._table_kwargs = table_kwargs if table_kwargs is not None else {}
 
         self.profiler: Optional[_PROFILER] = None
         self.function_events: Optional["EventList"] = None
@@ -314,6 +317,19 @@ class PyTorchProfiler(Profiler):
                 f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
             )
 
+        for key in self._table_kwargs:
+            if key in {"sort_by", "row_limit"}:
+                raise KeyError(
+                    f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler."
+                )
+            valid_table_keys = set(inspect.signature(EventList.table).parameters.keys()) - {
+                "self",
+                "sort_by",
+                "row_limit",
+            }
+            if key not in valid_table_keys:
+                raise KeyError(f"Found invalid table_kwargs key: {key}. Should be within {valid_table_keys}.")
+
     def _init_kineto(self, profiler_kwargs: Any) -> None:
         has_schedule = "schedule" in profiler_kwargs
         self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
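The validation above derives the legal key set from EventList.table's own signature instead of hard-coding it, so it tracks whatever the installed torch version supports. A standalone sketch of that discovery step (the keys shown in the comment are typical for recent torch releases, not guaranteed):

    import inspect

    from torch.autograd.profiler import EventList

    # Every keyword parameter of EventList.table is allowed, except the ones
    # PyTorchProfiler already exposes as first-class arguments.
    valid_table_keys = set(inspect.signature(EventList.table).parameters) - {"self", "sort_by", "row_limit"}
    print(valid_table_keys)
    # Typically: {'max_name_column_width', 'max_src_column_width',
    #             'max_shapes_column_width', 'header', 'top_level_events_only'}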
@@ -485,7 +501,7 @@ class PyTorchProfiler(Profiler):
             self.function_events.export_chrome_trace(path_to_trace)
 
         data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes)
-        table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit)
+        table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit, **self._table_kwargs)
 
         recorded_stats = {"records": table}
         return self._stats_to_str(recorded_stats)
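Stripped of the Lightning plumbing, the forwarding amounts to passing extra keywords through to EventList.table. A sketch against torch's autograd profiler directly (the column width is illustrative and assumes a torch version whose table() accepts max_name_column_width):

    import torch
    from torch.autograd import profiler

    with profiler.profile() as prof:
        torch.ones(1000)

    # key_averages() returns an EventList; PyTorchProfiler's table_kwargs
    # would land here as the extra keyword arguments.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20, max_name_column_width=100))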
The remaining hunk extends the profiler test suite with coverage for the new argument:
@@ -619,3 +619,38 @@ def test_profile_callbacks(tmpdir):
         e.name == "[pl][profile][Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
         for e in pytorch_profiler.function_events
     )
+
+
+@RunIf(min_python="3.10")
+def test_profiler_table_kwargs_summary_length(tmpdir):
+    """Test if setting max_name_column_width in table_kwargs changes table width."""
+
+    summaries = []
+    # Default table_kwargs (None) sets max_name_column_width to 55
+    for table_kwargs in [{"max_name_column_width": 1}, {"max_name_column_width": 5}, None]:
+        pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None, table_kwargs=table_kwargs)
+
+        with pytorch_profiler.profile("a"):
+            torch.ones(1)
+        pytorch_profiler.describe()
+        summaries.append(pytorch_profiler.summary())
+
+    # Check if setting max_name_column_width results in a wider table (more dashes)
+    assert summaries[0].count("-") < summaries[1].count("-")
+    assert summaries[1].count("-") < summaries[2].count("-")
+
+
+def test_profiler_invalid_table_kwargs(tmpdir):
+    """Test if passing invalid keyword arguments raise expected error."""
+
+    for key in {"row_limit", "sort_by"}:
+        with pytest.raises(
+            KeyError,
+            match=f"Found invalid table_kwargs key: {key}. This is already a positional argument of the Profiler.",
+        ):
+            PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile")
+
+    for key in {"self", "non_existent_keyword_arg"}:
+        with pytest.raises(KeyError) as exc_info:
+            PyTorchProfiler(table_kwargs={key: None}, dirpath=tmpdir, filename="profile")
+        assert exc_info.value.args[0].startswith(f"Found invalid table_kwargs key: {key}.")
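Finally, an end-to-end sketch of where the customized table surfaces in practice (BoringModel and the trainer limits are placeholder choices, not part of this commit):

    from lightning.pytorch import Trainer
    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.profilers import PyTorchProfiler

    # The profiler summary emitted when fit() tears down is rendered with
    # the supplied table_kwargs.
    profiler = PyTorchProfiler(table_kwargs={"max_name_column_width": 100})
    trainer = Trainer(profiler=profiler, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
    trainer.fit(BoringModel())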