Enable dumping raw prof files in `AdvancedProfiler` (#19703)

Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
This commit is contained in:
Alexander Zhipa 2024-07-15 10:40:32 -04:00 committed by GitHub
parent 2dc9c3d933
commit 74470a6dbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 40 additions and 0 deletions

View File

@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))
- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/pull/19703))
-
### Changed

View File

@ -16,13 +16,17 @@
import cProfile
import io
import logging
import os
import pstats
import tempfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from typing_extensions import override
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.rank_zero import rank_zero_only
log = logging.getLogger(__name__)
@ -40,6 +44,7 @@ class AdvancedProfiler(Profiler):
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
line_count_restriction: float = 1.0,
dump_stats: bool = False,
) -> None:
"""
Args:
@ -54,6 +59,8 @@ class AdvancedProfiler(Profiler):
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
Raises:
ValueError:
If you attempt to stop recording an action which was never started.
@ -61,6 +68,7 @@ class AdvancedProfiler(Profiler):
super().__init__(dirpath=dirpath, filename=filename)
self.profiled_actions: Dict[str, cProfile.Profile] = {}
self.line_count_restriction = line_count_restriction
self.dump_stats = dump_stats
@override
def start(self, action_name: str) -> None:
@ -75,10 +83,27 @@ class AdvancedProfiler(Profiler):
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()
def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None:
    """Write the raw ``cProfile`` stats of one action to a ``.prof`` file under ``self.dirpath``.

    The destination may live on any filesystem resolved by ``get_filesystem``, so the stats are
    first dumped to a local scratch file and then copied over.

    Args:
        action_name: Name of the profiled action; used to build the destination filename.
        profile: The profile whose raw stats should be saved.

    Raises:
        ValueError:
            If ``self.dirpath`` is not set.

    """
    # An explicit check instead of ``assert`` so the validation survives ``python -O``.
    if not self.dirpath:
        raise ValueError("Dumping raw profiler stats requires `dirpath` to be set.")

    dst_filename = self._prepare_filename(action_name=action_name, extension=".prof")
    dst_filepath = os.path.join(self.dirpath, dst_filename)
    dst_fs = get_filesystem(dst_filepath)
    dst_fs.mkdirs(self.dirpath, exist_ok=True)

    # ``cProfile.Profile.dump_stats`` only accepts a local file path, so dump locally first.
    # NOTE(review): the scratch directory is created inside the current working directory with a
    # "test" prefix and the rank as suffix -- confirm whether the system temp dir would be preferable.
    with tempfile.TemporaryDirectory(
        prefix="test", suffix=str(rank_zero_only.rank), dir=os.getcwd()
    ) as tmp_dir:
        src_filepath = os.path.join(tmp_dir, "tmp.prof")
        profile.dump_stats(src_filepath)
        # Open the destination only after the dump succeeded, so a failing dump does not leave
        # an empty or truncated ``.prof`` file behind.
        with open(src_filepath, "rb") as src_file, dst_fs.open(dst_filepath, "wb") as dst_file:
            dst_file.write(src_file.read())
@override
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
if self.dump_stats:
self._dump_stats(action_name, pr)
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
ps.print_stats(self.line_count_restriction)

View File

@ -308,6 +308,19 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
assert len(data) > 0
def test_advanced_profiler_dump_states(tmp_path):
    """Ensure the profiler dumps raw stats to ``dirpath`` during summary."""
    advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
    action_name = "test"
    # record at least one event
    with advanced_profiler.profile(action_name):
        pass
    # dump_stats to file
    advanced_profiler.describe()
    # the raw stats file is expected to be named after the recorded action
    path = advanced_profiler.dirpath / f"{action_name}.prof"
    data = path.read_bytes()
    assert len(data) > 0
def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"