Enable dumping raw prof files in `AdvancedProfiler` (#19703)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
parent
2dc9c3d933
commit
74470a6dbd
|
@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))
|
||||
|
||||
- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))
|
||||
|
||||
-
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -16,13 +16,17 @@
|
|||
import cProfile
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import pstats
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.cloud_io import get_filesystem
|
||||
from lightning.pytorch.profilers.profiler import Profiler
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_only
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -40,6 +44,7 @@ class AdvancedProfiler(Profiler):
|
|||
dirpath: Optional[Union[str, Path]] = None,
|
||||
filename: Optional[str] = None,
|
||||
line_count_restriction: float = 1.0,
|
||||
dump_stats: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
|
@ -54,6 +59,8 @@ class AdvancedProfiler(Profiler):
|
|||
reported for each action. either an integer (to select a count of lines),
|
||||
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
|
||||
|
||||
dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If you attempt to stop recording an action which was never started.
|
||||
|
@ -61,6 +68,7 @@ class AdvancedProfiler(Profiler):
|
|||
super().__init__(dirpath=dirpath, filename=filename)
|
||||
self.profiled_actions: Dict[str, cProfile.Profile] = {}
|
||||
self.line_count_restriction = line_count_restriction
|
||||
self.dump_stats = dump_stats
|
||||
|
||||
@override
|
||||
def start(self, action_name: str) -> None:
|
||||
|
@ -75,10 +83,27 @@ class AdvancedProfiler(Profiler):
|
|||
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
|
||||
pr.disable()
|
||||
|
||||
def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None:
|
||||
assert self.dirpath
|
||||
dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
|
||||
dst_fs = get_filesystem(dst_filepath)
|
||||
dst_fs.mkdirs(self.dirpath, exist_ok=True)
|
||||
# temporarily save to local since pstats can only dump into a local file
|
||||
with tempfile.TemporaryDirectory(
|
||||
prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
|
||||
) as tmp_dir, dst_fs.open(dst_filepath, "wb") as dst_file:
|
||||
src_filepath = os.path.join(tmp_dir, "tmp.prof")
|
||||
profile.dump_stats(src_filepath)
|
||||
src_fs = get_filesystem(src_filepath)
|
||||
with src_fs.open(src_filepath, "rb") as src_file:
|
||||
dst_file.write(src_file.read())
|
||||
|
||||
@override
|
||||
def summary(self) -> str:
|
||||
recorded_stats = {}
|
||||
for action_name, pr in self.profiled_actions.items():
|
||||
if self.dump_stats:
|
||||
self._dump_stats(action_name, pr)
|
||||
s = io.StringIO()
|
||||
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
|
||||
ps.print_stats(self.line_count_restriction)
|
||||
|
|
|
@ -308,6 +308,19 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
|
|||
assert len(data) > 0
|
||||
|
||||
|
||||
def test_advanced_profiler_dump_states(tmp_path):
|
||||
advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
|
||||
"""Ensure the profiler dump stats during summary."""
|
||||
# record at least one event
|
||||
with advanced_profiler.profile(action_name := "test"):
|
||||
pass
|
||||
# dump_stats to file
|
||||
advanced_profiler.describe()
|
||||
path = advanced_profiler.dirpath / f"{action_name}.prof"
|
||||
data = path.read_bytes()
|
||||
assert len(data) > 0
|
||||
|
||||
|
||||
def test_advanced_profiler_value_errors(advanced_profiler):
|
||||
"""Ensure errors are raised where expected."""
|
||||
action = "test"
|
||||
|
|
Loading…
Reference in New Issue