diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 39eb6e7265..2cb0d04d0a 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made saving non-distributed checkpoints fully atomic ([#20011](https://github.com/Lightning-AI/pytorch-lightning/pull/20011))
 
+- Added `dump_stats` flag to `AdvancedProfiler` ([#19703](https://github.com/Lightning-AI/pytorch-lightning/issues/19703))
+
 -
 
 ### Changed
diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py
index 9e166b5e34..467b47124e 100644
--- a/src/lightning/pytorch/profilers/advanced.py
+++ b/src/lightning/pytorch/profilers/advanced.py
@@ -16,13 +16,17 @@
 import cProfile
 import io
 import logging
+import os
 import pstats
+import tempfile
 from pathlib import Path
 from typing import Dict, Optional, Tuple, Union
 
 from typing_extensions import override
 
+from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.pytorch.profilers.profiler import Profiler
+from lightning.pytorch.utilities.rank_zero import rank_zero_only
 
 log = logging.getLogger(__name__)
 
@@ -40,6 +44,7 @@ class AdvancedProfiler(Profiler):
         dirpath: Optional[Union[str, Path]] = None,
         filename: Optional[str] = None,
         line_count_restriction: float = 1.0,
+        dump_stats: bool = False,
     ) -> None:
         """
         Args:
@@ -54,6 +59,8 @@ class AdvancedProfiler(Profiler):
                 reported for each action. either an integer (to select a count of lines),
                 or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
 
+            dump_stats: Whether to save raw profiler results. When ``True`` then ``dirpath`` must be provided.
+
         Raises:
             ValueError:
                 If you attempt to stop recording an action which was never started.
@@ -61,6 +68,7 @@ class AdvancedProfiler(Profiler):
         super().__init__(dirpath=dirpath, filename=filename)
         self.profiled_actions: Dict[str, cProfile.Profile] = {}
         self.line_count_restriction = line_count_restriction
+        self.dump_stats = dump_stats
 
     @override
     def start(self, action_name: str) -> None:
@@ -75,10 +83,36 @@ class AdvancedProfiler(Profiler):
             raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
         pr.disable()
 
+    def _dump_stats(self, action_name: str, profile: cProfile.Profile) -> None:
+        """Save the raw stats of ``profile`` as a ``.prof`` file under ``self.dirpath``.
+
+        Args:
+            action_name: Name of the profiled action, used to build the output filename.
+            profile: Profile whose raw stats are dumped.
+
+        Raises:
+            ValueError: If ``dirpath`` is not set.
+
+        """
+        if not self.dirpath:
+            raise ValueError("`dump_stats=True` requires `dirpath` to be set.")
+        dst_filepath = os.path.join(self.dirpath, self._prepare_filename(action_name=action_name, extension=".prof"))
+        dst_fs = get_filesystem(dst_filepath)
+        dst_fs.mkdirs(self.dirpath, exist_ok=True)
+        # cProfile can only dump to a local path: stage in the system temp dir, then copy to the destination
+        with tempfile.TemporaryDirectory(prefix="profiler_", suffix=f"_rank{rank_zero_only.rank}") as tmp_dir:
+            src_filepath = os.path.join(tmp_dir, "tmp.prof")
+            # dump first so that a failed dump does not leave an empty destination file behind
+            profile.dump_stats(src_filepath)
+            with open(src_filepath, "rb") as src_file, dst_fs.open(dst_filepath, "wb") as dst_file:
+                dst_file.write(src_file.read())
+
     @override
     def summary(self) -> str:
         recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
+            if self.dump_stats:
+                self._dump_stats(action_name, pr)
             s = io.StringIO()
             ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
             ps.print_stats(self.line_count_restriction)
diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py
index 6f92788e27..5b0c13e605 100644
--- a/tests/tests_pytorch/profilers/test_profiler.py
+++ b/tests/tests_pytorch/profilers/test_profiler.py
@@ -308,6 +308,19 @@ def test_advanced_profiler_describe(tmp_path, advanced_profiler):
     assert len(data) > 0
 
 
+def test_advanced_profiler_dump_states(tmp_path):
+    """Ensure the profiler dumps stats during summary."""
+    advanced_profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
+    # record at least one event
+    with advanced_profiler.profile(action_name := "test"):
+        pass
+    # dump_stats to file
+    advanced_profiler.describe()
+    path = advanced_profiler.dirpath / f"{action_name}.prof"
+    data = path.read_bytes()
+    assert len(data) > 0
+
+
 def test_advanced_profiler_value_errors(advanced_profiler):
     """Ensure errors are raised where expected."""
     action = "test"