Include hook's object name when profiling (#11026)

Author: Danielle Pintz, 2021-12-20 06:18:24 -08:00 (committed by GitHub)
parent 29eb9cccf2
commit b1baf460d9
4 changed files with 64 additions and 41 deletions

View File

@@ -114,6 +114,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
- Changed profiler to index and display the names of the hooks with a new pattern [<base class>]<class>.<hook name> ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
- Changed `batch_to_device` entry in profiling from stage-specific to generic, to match profiling of other hooks ([#11031](https://github.com/PyTorchLightning/pytorch-lightning/pull/11031))
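For orientation, a minimal sketch of the naming scheme the entry above describes; `profile_action_name` and the bare `BoringModel` class are hypothetical stand-ins, not part of the change:

```python
# Hypothetical helper (illustration only, not part of this commit) showing the
# new "[<base class>]<class>.<hook name>" pattern; in the actual change the
# names are built inline at the Trainer call sites further below.
def profile_action_name(base_class: str, owner: object, hook_name: str) -> str:
    return f"[{base_class}]{type(owner).__name__}.{hook_name}"


class BoringModel:  # stand-in for a LightningModule subclass
    pass


print(profile_action_name("LightningModule", BoringModel(), "training_step"))
# -> [LightningModule]BoringModel.training_step
```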

View File

@@ -131,51 +131,51 @@ class ScheduleWrapper:
@property
def is_training(self) -> bool:
return self._current_action is not None and (
self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step"
return self._current_action.startswith("optimizer_step_with_closure_") or self._current_action.endswith(
"training_step"
)
@property
def num_step(self) -> int:
if self.is_training:
return self._num_optimizer_step_with_closure
if self._current_action == "validation_step":
if self._current_action.endswith("validation_step"):
return self._num_validation_step
if self._current_action == "test_step":
if self._current_action.endswith("test_step"):
return self._num_test_step
if self._current_action == "predict_step":
if self._current_action.endswith("predict_step"):
return self._num_predict_step
return 0
def _step(self) -> None:
if self.is_training:
self._num_optimizer_step_with_closure += 1
elif self._current_action == "validation_step":
if self._start_action_name == "on_fit_start":
elif self._current_action.endswith("validation_step"):
if self._start_action_name.endswith("on_fit_start"):
if self._num_optimizer_step_with_closure > 0:
self._num_validation_step += 1
else:
self._num_validation_step += 1
elif self._current_action == "test_step":
elif self._current_action.endswith("test_step"):
self._num_test_step += 1
elif self._current_action == "predict_step":
elif self._current_action.endswith("predict_step"):
self._num_predict_step += 1
@property
def has_finished(self) -> bool:
if self.is_training:
return self._optimizer_step_with_closure_reached_end
if self._current_action == "validation_step":
if self._current_action.endswith("validation_step"):
return self._validation_step_reached_end
if self._current_action == "test_step":
if self._current_action.endswith("test_step"):
return self._test_step_reached_end
if self._current_action == "predict_step":
if self._current_action.endswith("predict_step"):
return self._predict_step_reached_end
return False
def __call__(self, num_step: int) -> "ProfilerAction":
# ignore the provided input. Keep internal state instead.
if self.has_finished:
if self._current_action is None or self.has_finished:
return ProfilerAction.NONE
self._step()
@@ -183,11 +183,11 @@ class ScheduleWrapper:
if action == ProfilerAction.RECORD_AND_SAVE:
if self.is_training:
self._optimizer_step_with_closure_reached_end = True
elif self._current_action == "validation_step":
elif self._current_action.endswith("validation_step"):
self._validation_step_reached_end = True
elif self._current_action == "test_step":
elif self._current_action.endswith("test_step"):
self._test_step_reached_end = True
elif self._current_action == "predict_step":
elif self._current_action.endswith("predict_step"):
self._predict_step_reached_end = True
return action
@@ -340,11 +340,11 @@ class PyTorchProfiler(BaseProfiler):
trainer = self._lightning_module.trainer
if self._schedule.is_training:
return trainer.num_training_batches
if self._schedule._current_action == "validation_step":
if self._schedule._current_action.endswith("validation_step"):
return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
if self._schedule._current_action == "test_step":
if self._schedule._current_action.endswith("test_step"):
return sum(trainer.num_test_batches)
if self._schedule._current_action == "predict_step":
if self._schedule._current_action.endswith("predict_step"):
return sum(trainer.num_predict_batches)
def _should_override_schedule(self) -> bool:
@@ -373,8 +373,7 @@ class PyTorchProfiler(BaseProfiler):
return activities
def start(self, action_name: str) -> None:
if self.profiler is None and action_name in self._record_functions_start:
if self.profiler is None and any(action_name.endswith(func) for func in self._record_functions_start):
# close profiler if it is already opened. might happen if 2 profilers
# are created and the first one did not call `describe`
if torch.autograd._profiler_enabled():
@@ -403,7 +402,10 @@
if (
self.profiler is not None
and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
and (
any(action_name.endswith(func) for func in self._record_functions)
or action_name.startswith(self.RECORD_FUNCTION_PREFIX)
)
and action_name not in self._recording_map
):
@@ -420,9 +422,9 @@
return
if self.profiler is not None and (
action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
any(action_name.endswith(func) for func in self.STEP_FUNCTIONS)
or action_name.startswith(self.STEP_FUNCTION_PREFIX)
):
if self._schedule is not None:
self._schedule.pre_step(action_name)
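Because every action name now carries the owning object as a prefix, an exact comparison against a bare hook name such as `"validation_step"` can no longer match, which is why the hunks above switch `ScheduleWrapper` and `PyTorchProfiler` to suffix checks. A minimal sketch of the difference, with an illustrative action name rather than one captured from a real run:

```python
# Old exact match vs. new suffix match; the action name below is illustrative.
action_name = "[Strategy]SingleDevicePlugin.validation_step"
STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"}

assert action_name != "validation_step"          # the old equality check no longer fires
assert action_name.endswith("validation_step")   # the new suffix check does
assert any(action_name.endswith(func) for func in STEP_FUNCTIONS)
```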

View File

@@ -1523,8 +1523,7 @@ class Trainer(
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name
# TODO: when profiling separate hook name by hook object name (e.g. Callback, LM)
with self.profiler.profile(hook_name):
with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
@@ -1563,7 +1562,7 @@
for callback in self.callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
with self.profiler.profile(hook_name):
with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
fn(self, self.lightning_module, *args, **kwargs)
if pl_module:
@@ -1584,7 +1583,7 @@
if not callable(fn):
return
with self.profiler.profile(hook_name):
with self.profiler.profile(f"[Strategy]{self.training_type_plugin.__class__.__name__}.{hook_name}"):
output = fn(*args, **kwargs)
# restore current_fx when nested context
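Taken together, the three call sites above label profiler sections with the owning object: LightningModule hooks use the module's class name, callback hooks use `callback.state_key` (keeping two instances of the same callback class distinguishable), and training-type plugin ("Strategy") hooks use the plugin's class name. A small runnable sketch of the resulting names, using a print-only stand-in profiler and illustrative class names:

```python
from contextlib import contextmanager


class _PrintProfiler:
    """Stand-in profiler used only to make the sketch runnable."""

    @contextmanager
    def profile(self, action_name: str):
        print("profiling:", action_name)
        yield


profiler = _PrintProfiler()
hook_name = "on_train_start"

with profiler.profile(f"[LightningModule]BoringModel.{hook_name}"):
    ...  # LightningModule hook
with profiler.profile(f"[Callback]EarlyStopping{{'monitor': 'val_loss', 'mode': 'min'}}.{hook_name}"):
    ...  # Callback hook; callback.state_key supplies the bracketed part
with profiler.profile(f"[Strategy]SingleDevicePlugin.{hook_name}"):
    ...  # training-type plugin ("Strategy") hook
```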

View File

@@ -22,7 +22,7 @@ import pytest
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
@@ -310,10 +310,13 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
gpus=2,
)
trainer.fit(model)
expected = {"validation_step"}
expected = {"[Strategy]DDPPlugin.validation_step"}
if not _KINETO_AVAILABLE:
expected |= {"training_step_and_backward", "training_step", "backward"}
expected |= {
"training_step_and_backward",
"[Strategy]DDPPlugin.training_step",
"[Strategy]DDPPlugin.backward",
}
for name in expected:
assert sum(e.name == name for e in pytorch_profiler.function_events), name
@@ -330,7 +333,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
assert len(files) == 2, files
local_rank = trainer.local_rank
assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files)
assert any(f"{local_rank}-validation_step" in f for f in files)
assert any(f"{local_rank}-[Strategy]DDPPlugin.validation_step" in f for f in files)
@RunIf(standalone=True)
@@ -343,7 +346,7 @@ def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
trainer.fit(model)
assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
assert sum(e.name == "[Strategy]SingleDevicePlugin.validation_step" for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
@@ -351,8 +354,6 @@
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
@@ -365,7 +366,7 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
getattr(trainer, fn)(model)
assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)
assert sum(e.name.endswith(f"{step_name}_step") for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
@@ -373,8 +374,6 @@
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
def test_pytorch_profiler_nested(tmpdir):
@@ -418,11 +417,9 @@ def test_pytorch_profiler_logger_collection(tmpdir):
assert not look_for_trace(tmpdir)
model = BoringModel()
# Wrap the logger in a list so it becomes a LoggerCollection
logger = [TensorBoardLogger(save_dir=tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
assert isinstance(trainer.logger, LoggerCollection)
trainer.fit(model)
assert look_for_trace(tmpdir)
@@ -552,3 +549,25 @@ def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_confi
getattr(trainer, trainer_fn)(model)
assert trainer.profiler._schedule is None
warning_cache.clear()
def test_profile_callbacks(tmpdir):
"""Checks if profiling callbacks works correctly, specifically when there are two of the same callback type."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", record_functions=set("on_train_end"))
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
profiler=pytorch_profiler,
callbacks=[EarlyStopping("val_loss"), EarlyStopping("train_loss")],
)
trainer.fit(model)
assert sum(
e.name == "[Callback]EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}.on_validation_start"
for e in pytorch_profiler.function_events
)
assert sum(
e.name == "[Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
for e in pytorch_profiler.function_events
)
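The two assertions above work because `Callback.state_key` folds the identifying constructor arguments into the name, which is what keeps the two `EarlyStopping` instances apart in the profiler output. A rough, self-contained sketch of that idea; the `Sketch*` classes are hypothetical and the real logic lives in `pytorch_lightning.callbacks`:

```python
class SketchCallback:
    """Hypothetical minimal callback base; mimics how a state_key can be derived."""

    def _generate_state_key(self, **kwargs: object) -> str:
        return f"{type(self).__name__}{kwargs!r}"


class SketchEarlyStopping(SketchCallback):
    def __init__(self, monitor: str, mode: str = "min") -> None:
        self.monitor = monitor
        self.mode = mode

    @property
    def state_key(self) -> str:
        return self._generate_state_key(monitor=self.monitor, mode=self.mode)


print(SketchEarlyStopping("val_loss").state_key)
# -> SketchEarlyStopping{'monitor': 'val_loss', 'mode': 'min'}
print(SketchEarlyStopping("train_loss").state_key)
# -> SketchEarlyStopping{'monitor': 'train_loss', 'mode': 'min'}
```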