From b1baf460d900a1e9982823b4428a3cd108a59d06 Mon Sep 17 00:00:00 2001
From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
Date: Mon, 20 Dec 2021 06:18:24 -0800
Subject: [PATCH] Include hook's object name when profiling (#11026)
---
CHANGELOG.md | 3 ++
pytorch_lightning/profiler/pytorch.py | 50 ++++++++++++++-------------
pytorch_lightning/trainer/trainer.py | 7 ++--
tests/profiler/test_profiler.py | 45 +++++++++++++++++-------
4 files changed, 64 insertions(+), 41 deletions(-)
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4de744df15..7c673a713d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -114,6 +114,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
+- Changed profiler to index and display the names of the hooks with a new pattern []. ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
+
+
- Changed `batch_to_device` entry in profiling from stage-specific to generic, to match profiling of other hooks ([#11031](https://github.com/PyTorchLightning/pytorch-lightning/pull/11031))
diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index f5c5968c0f..042a70966a 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -131,51 +131,51 @@ class ScheduleWrapper:
@property
def is_training(self) -> bool:
- return self._current_action is not None and (
- self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step"
+ return self._current_action.startswith("optimizer_step_with_closure_") or self._current_action.endswith(
+ "training_step"
)
@property
def num_step(self) -> int:
if self.is_training:
return self._num_optimizer_step_with_closure
- if self._current_action == "validation_step":
+ if self._current_action.endswith("validation_step"):
return self._num_validation_step
- if self._current_action == "test_step":
+ if self._current_action.endswith("test_step"):
return self._num_test_step
- if self._current_action == "predict_step":
+ if self._current_action.endswith("predict_step"):
return self._num_predict_step
return 0
def _step(self) -> None:
if self.is_training:
self._num_optimizer_step_with_closure += 1
- elif self._current_action == "validation_step":
- if self._start_action_name == "on_fit_start":
+ elif self._current_action.endswith("validation_step"):
+ if self._start_action_name.endswith("on_fit_start"):
if self._num_optimizer_step_with_closure > 0:
self._num_validation_step += 1
else:
self._num_validation_step += 1
- elif self._current_action == "test_step":
+ elif self._current_action.endswith("test_step"):
self._num_test_step += 1
- elif self._current_action == "predict_step":
+ elif self._current_action.endswith("predict_step"):
self._num_predict_step += 1
@property
def has_finished(self) -> bool:
if self.is_training:
return self._optimizer_step_with_closure_reached_end
- if self._current_action == "validation_step":
+ if self._current_action.endswith("validation_step"):
return self._validation_step_reached_end
- if self._current_action == "test_step":
+ if self._current_action.endswith("test_step"):
return self._test_step_reached_end
- if self._current_action == "predict_step":
+ if self._current_action.endswith("predict_step"):
return self._predict_step_reached_end
return False
def __call__(self, num_step: int) -> "ProfilerAction":
# ignore the provided input. Keep internal state instead.
- if self.has_finished:
+ if self._current_action is None or self.has_finished:
return ProfilerAction.NONE
self._step()
@@ -183,11 +183,11 @@ class ScheduleWrapper:
if action == ProfilerAction.RECORD_AND_SAVE:
if self.is_training:
self._optimizer_step_with_closure_reached_end = True
- elif self._current_action == "validation_step":
+ elif self._current_action.endswith("validation_step"):
self._validation_step_reached_end = True
- elif self._current_action == "test_step":
+ elif self._current_action.endswith("test_step"):
self._test_step_reached_end = True
- elif self._current_action == "predict_step":
+ elif self._current_action.endswith("predict_step"):
self._predict_step_reached_end = True
return action
@@ -340,11 +340,11 @@ class PyTorchProfiler(BaseProfiler):
trainer = self._lightning_module.trainer
if self._schedule.is_training:
return trainer.num_training_batches
- if self._schedule._current_action == "validation_step":
+ if self._schedule._current_action.endswith("validation_step"):
return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
- if self._schedule._current_action == "test_step":
+ if self._schedule._current_action.endswith("test_step"):
return sum(trainer.num_test_batches)
- if self._schedule._current_action == "predict_step":
+ if self._schedule._current_action.endswith("predict_step"):
return sum(trainer.num_predict_batches)
def _should_override_schedule(self) -> bool:
@@ -373,8 +373,7 @@ class PyTorchProfiler(BaseProfiler):
return activities
def start(self, action_name: str) -> None:
- if self.profiler is None and action_name in self._record_functions_start:
-
+ if self.profiler is None and any(action_name.endswith(func) for func in self._record_functions_start):
# close profiler if it is already opened. might happen if 2 profilers
# are created and the first one did not call `describe`
if torch.autograd._profiler_enabled():
@@ -403,7 +402,10 @@ class PyTorchProfiler(BaseProfiler):
if (
self.profiler is not None
- and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
+ and (
+ any(action_name.endswith(func) for func in self._record_functions)
+ or action_name.startswith(self.RECORD_FUNCTION_PREFIX)
+ )
and action_name not in self._recording_map
):
@@ -420,9 +422,9 @@ class PyTorchProfiler(BaseProfiler):
return
if self.profiler is not None and (
- action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
+ any(action_name.endswith(func) for func in self.STEP_FUNCTIONS)
+ or action_name.startswith(self.STEP_FUNCTION_PREFIX)
):
-
if self._schedule is not None:
self._schedule.pre_step(action_name)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index dc02ada83f..246f53d280 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1523,8 +1523,7 @@ class Trainer(
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
- # TODO: when profiling separate hook name by hook object name (e.g. Callback, LM)
- with self.profiler.profile(hook_name):
+ with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
@@ -1563,7 +1562,7 @@ class Trainer(
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
- with self.profiler.profile(hook_name):
+ with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
fn(self, self.lightning_module, *args, **kwargs)
if pl_module:
@@ -1584,7 +1583,7 @@ class Trainer(
if not callable(fn):
return
- with self.profiler.profile(hook_name):
+ with self.profiler.profile(f"[Strategy]{self.training_type_plugin.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py
index 3108c83d4d..5395c37814 100644
--- a/tests/profiler/test_profiler.py
+++ b/tests/profiler/test_profiler.py
@@ -22,7 +22,7 @@ import pytest
import torch
from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.callbacks import StochasticWeightAveraging
+from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
@@ -310,10 +310,13 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
gpus=2,
)
trainer.fit(model)
-
- expected = {"validation_step"}
+ expected = {"[Strategy]DDPPlugin.validation_step"}
if not _KINETO_AVAILABLE:
- expected |= {"training_step_and_backward", "training_step", "backward"}
+ expected |= {
+ "training_step_and_backward",
+ "[Strategy]DDPPlugin.training_step",
+ "[Strategy]DDPPlugin.backward",
+ }
for name in expected:
assert sum(e.name == name for e in pytorch_profiler.function_events), name
@@ -330,7 +333,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
assert len(files) == 2, files
local_rank = trainer.local_rank
assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files)
- assert any(f"{local_rank}-validation_step" in f for f in files)
+ assert any(f"{local_rank}-[Strategy]DDPPlugin.validation_step" in f for f in files)
@RunIf(standalone=True)
@@ -343,7 +346,7 @@ def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
trainer.fit(model)
- assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
+ assert sum(e.name == "[Strategy]SingleDevicePlugin.validation_step" for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
@@ -351,8 +354,6 @@ def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
- path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
- assert path.read_text("utf-8")
@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
@@ -365,7 +366,7 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
getattr(trainer, fn)(model)
- assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)
+ assert sum(e.name.endswith(f"{step_name}_step") for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
@@ -373,8 +374,6 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
- path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
- assert path.read_text("utf-8")
def test_pytorch_profiler_nested(tmpdir):
@@ -418,11 +417,9 @@ def test_pytorch_profiler_logger_collection(tmpdir):
assert not look_for_trace(tmpdir)
model = BoringModel()
-
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
-
assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)
assert look_for_trace(tmpdir)
@@ -552,3 +549,25 @@ def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_confi
getattr(trainer, trainer_fn)(model)
assert trainer.profiler._schedule is None
warning_cache.clear()
+
+
+def test_profile_callbacks(tmpdir):
+ """Checks if profiling callbacks works correctly, specifically when there are two of the same callback type."""
+
+ pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", record_functions=set("on_train_end"))
+ model = BoringModel()
+ trainer = Trainer(
+ default_root_dir=tmpdir,
+ fast_dev_run=1,
+ profiler=pytorch_profiler,
+ callbacks=[EarlyStopping("val_loss"), EarlyStopping("train_loss")],
+ )
+ trainer.fit(model)
+ assert sum(
+ e.name == "[Callback]EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}.on_validation_start"
+ for e in pytorch_profiler.function_events
+ )
+ assert sum(
+ e.name == "[Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
+ for e in pytorch_profiler.function_events
+ )