Include hook's object name when profiling (#11026)
This commit is contained in:
parent
29eb9cccf2
commit
b1baf460d9
|
@ -114,6 +114,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Some configuration errors that were previously raised as `MisconfigurationException`s will now be raised as `ProcessRaisedException` (torch>=1.8) or as `Exception` (torch<1.8)
|
||||
|
||||
|
||||
- Changed profiler to index and display the names of the hooks with a new pattern `[<base class>]<class>.<hook name>` ([#11026](https://github.com/PyTorchLightning/pytorch-lightning/pull/11026))
|
||||
|
||||
|
||||
- Changed `batch_to_device` entry in profiling from stage-specific to generic, to match profiling of other hooks ([#11031](https://github.com/PyTorchLightning/pytorch-lightning/pull/11031))
|
||||
|
||||
|
||||
|
|
|
@ -131,51 +131,51 @@ class ScheduleWrapper:
|
|||
|
||||
@property
|
||||
def is_training(self) -> bool:
|
||||
return self._current_action is not None and (
|
||||
self._current_action.startswith("optimizer_step_with_closure_") or self._current_action == "training_step"
|
||||
return self._current_action.startswith("optimizer_step_with_closure_") or self._current_action.endswith(
|
||||
"training_step"
|
||||
)
|
||||
|
||||
@property
|
||||
def num_step(self) -> int:
|
||||
if self.is_training:
|
||||
return self._num_optimizer_step_with_closure
|
||||
if self._current_action == "validation_step":
|
||||
if self._current_action.endswith("validation_step"):
|
||||
return self._num_validation_step
|
||||
if self._current_action == "test_step":
|
||||
if self._current_action.endswith("test_step"):
|
||||
return self._num_test_step
|
||||
if self._current_action == "predict_step":
|
||||
if self._current_action.endswith("predict_step"):
|
||||
return self._num_predict_step
|
||||
return 0
|
||||
|
||||
def _step(self) -> None:
|
||||
if self.is_training:
|
||||
self._num_optimizer_step_with_closure += 1
|
||||
elif self._current_action == "validation_step":
|
||||
if self._start_action_name == "on_fit_start":
|
||||
elif self._current_action.endswith("validation_step"):
|
||||
if self._start_action_name.endswith("on_fit_start"):
|
||||
if self._num_optimizer_step_with_closure > 0:
|
||||
self._num_validation_step += 1
|
||||
else:
|
||||
self._num_validation_step += 1
|
||||
elif self._current_action == "test_step":
|
||||
elif self._current_action.endswith("test_step"):
|
||||
self._num_test_step += 1
|
||||
elif self._current_action == "predict_step":
|
||||
elif self._current_action.endswith("predict_step"):
|
||||
self._num_predict_step += 1
|
||||
|
||||
@property
|
||||
def has_finished(self) -> bool:
|
||||
if self.is_training:
|
||||
return self._optimizer_step_with_closure_reached_end
|
||||
if self._current_action == "validation_step":
|
||||
if self._current_action.endswith("validation_step"):
|
||||
return self._validation_step_reached_end
|
||||
if self._current_action == "test_step":
|
||||
if self._current_action.endswith("test_step"):
|
||||
return self._test_step_reached_end
|
||||
if self._current_action == "predict_step":
|
||||
if self._current_action.endswith("predict_step"):
|
||||
return self._predict_step_reached_end
|
||||
return False
|
||||
|
||||
def __call__(self, num_step: int) -> "ProfilerAction":
|
||||
# ignore the provided input. Keep internal state instead.
|
||||
if self.has_finished:
|
||||
if self._current_action is None or self.has_finished:
|
||||
return ProfilerAction.NONE
|
||||
|
||||
self._step()
|
||||
|
@ -183,11 +183,11 @@ class ScheduleWrapper:
|
|||
if action == ProfilerAction.RECORD_AND_SAVE:
|
||||
if self.is_training:
|
||||
self._optimizer_step_with_closure_reached_end = True
|
||||
elif self._current_action == "validation_step":
|
||||
elif self._current_action.endswith("validation_step"):
|
||||
self._validation_step_reached_end = True
|
||||
elif self._current_action == "test_step":
|
||||
elif self._current_action.endswith("test_step"):
|
||||
self._test_step_reached_end = True
|
||||
elif self._current_action == "predict_step":
|
||||
elif self._current_action.endswith("predict_step"):
|
||||
self._predict_step_reached_end = True
|
||||
return action
|
||||
|
||||
|
@ -340,11 +340,11 @@ class PyTorchProfiler(BaseProfiler):
|
|||
trainer = self._lightning_module.trainer
|
||||
if self._schedule.is_training:
|
||||
return trainer.num_training_batches
|
||||
if self._schedule._current_action == "validation_step":
|
||||
if self._schedule._current_action.endswith("validation_step"):
|
||||
return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
|
||||
if self._schedule._current_action == "test_step":
|
||||
if self._schedule._current_action.endswith("test_step"):
|
||||
return sum(trainer.num_test_batches)
|
||||
if self._schedule._current_action == "predict_step":
|
||||
if self._schedule._current_action.endswith("predict_step"):
|
||||
return sum(trainer.num_predict_batches)
|
||||
|
||||
def _should_override_schedule(self) -> bool:
|
||||
|
@ -373,8 +373,7 @@ class PyTorchProfiler(BaseProfiler):
|
|||
return activities
|
||||
|
||||
def start(self, action_name: str) -> None:
|
||||
if self.profiler is None and action_name in self._record_functions_start:
|
||||
|
||||
if self.profiler is None and any(action_name.endswith(func) for func in self._record_functions_start):
|
||||
# close profiler if it is already opened. might happen if 2 profilers
|
||||
# are created and the first one did not call `describe`
|
||||
if torch.autograd._profiler_enabled():
|
||||
|
@ -403,7 +402,10 @@ class PyTorchProfiler(BaseProfiler):
|
|||
|
||||
if (
|
||||
self.profiler is not None
|
||||
and (action_name in self._record_functions or action_name.startswith(self.RECORD_FUNCTION_PREFIX))
|
||||
and (
|
||||
any(action_name.endswith(func) for func in self._record_functions)
|
||||
or action_name.startswith(self.RECORD_FUNCTION_PREFIX)
|
||||
)
|
||||
and action_name not in self._recording_map
|
||||
):
|
||||
|
||||
|
@ -420,9 +422,9 @@ class PyTorchProfiler(BaseProfiler):
|
|||
return
|
||||
|
||||
if self.profiler is not None and (
|
||||
action_name in self.STEP_FUNCTIONS or action_name.startswith(self.STEP_FUNCTION_PREFIX)
|
||||
any(action_name.endswith(func) for func in self.STEP_FUNCTIONS)
|
||||
or action_name.startswith(self.STEP_FUNCTION_PREFIX)
|
||||
):
|
||||
|
||||
if self._schedule is not None:
|
||||
self._schedule.pre_step(action_name)
|
||||
|
||||
|
|
|
@ -1523,8 +1523,7 @@ class Trainer(
|
|||
prev_fx_name = pl_module._current_fx_name
|
||||
pl_module._current_fx_name = hook_name
|
||||
|
||||
# TODO: when profiling separate hook name by hook object name (e.g. Callback, LM)
|
||||
with self.profiler.profile(hook_name):
|
||||
with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
|
||||
output = fn(*args, **kwargs)
|
||||
|
||||
# restore current_fx when nested context
|
||||
|
@ -1563,7 +1562,7 @@ class Trainer(
|
|||
for callback in self.callbacks:
|
||||
fn = getattr(callback, hook_name)
|
||||
if callable(fn):
|
||||
with self.profiler.profile(hook_name):
|
||||
with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
|
||||
fn(self, self.lightning_module, *args, **kwargs)
|
||||
|
||||
if pl_module:
|
||||
|
@ -1584,7 +1583,7 @@ class Trainer(
|
|||
if not callable(fn):
|
||||
return
|
||||
|
||||
with self.profiler.profile(hook_name):
|
||||
with self.profiler.profile(f"[Strategy]{self.training_type_plugin.__class__.__name__}.{hook_name}"):
|
||||
output = fn(*args, **kwargs)
|
||||
|
||||
# restore current_fx when nested context
|
||||
|
|
|
@ -22,7 +22,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.callbacks import StochasticWeightAveraging
|
||||
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
|
||||
from pytorch_lightning.loggers.base import LoggerCollection
|
||||
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
|
||||
from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
|
||||
|
@ -310,10 +310,13 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
|
|||
gpus=2,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
expected = {"validation_step"}
|
||||
expected = {"[Strategy]DDPPlugin.validation_step"}
|
||||
if not _KINETO_AVAILABLE:
|
||||
expected |= {"training_step_and_backward", "training_step", "backward"}
|
||||
expected |= {
|
||||
"training_step_and_backward",
|
||||
"[Strategy]DDPPlugin.training_step",
|
||||
"[Strategy]DDPPlugin.backward",
|
||||
}
|
||||
for name in expected:
|
||||
assert sum(e.name == name for e in pytorch_profiler.function_events), name
|
||||
|
||||
|
@ -330,7 +333,7 @@ def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
|
|||
assert len(files) == 2, files
|
||||
local_rank = trainer.local_rank
|
||||
assert any(f"{local_rank}-optimizer_step_with_closure_" in f for f in files)
|
||||
assert any(f"{local_rank}-validation_step" in f for f in files)
|
||||
assert any(f"{local_rank}-[Strategy]DDPPlugin.validation_step" in f for f in files)
|
||||
|
||||
|
||||
@RunIf(standalone=True)
|
||||
|
@ -343,7 +346,7 @@ def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
|
|||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
|
||||
trainer.fit(model)
|
||||
|
||||
assert sum(e.name == "validation_step" for e in pytorch_profiler.function_events)
|
||||
assert sum(e.name == "[Strategy]SingleDevicePlugin.validation_step" for e in pytorch_profiler.function_events)
|
||||
|
||||
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
|
||||
assert path.read_text("utf-8")
|
||||
|
@ -351,8 +354,6 @@ def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
|
|||
if _KINETO_AVAILABLE:
|
||||
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
|
||||
assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
|
||||
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
|
||||
assert path.read_text("utf-8")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
|
||||
|
@ -365,7 +366,7 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
|
|||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
|
||||
getattr(trainer, fn)(model)
|
||||
|
||||
assert sum(e.name == f"{step_name}_step" for e in pytorch_profiler.function_events)
|
||||
assert sum(e.name.endswith(f"{step_name}_step") for e in pytorch_profiler.function_events)
|
||||
|
||||
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
|
||||
assert path.read_text("utf-8")
|
||||
|
@ -373,8 +374,6 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
|
|||
if _KINETO_AVAILABLE:
|
||||
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
|
||||
assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)
|
||||
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
|
||||
assert path.read_text("utf-8")
|
||||
|
||||
|
||||
def test_pytorch_profiler_nested(tmpdir):
|
||||
|
@ -418,11 +417,9 @@ def test_pytorch_profiler_logger_collection(tmpdir):
|
|||
assert not look_for_trace(tmpdir)
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
# Wrap the logger in a list so it becomes a LoggerCollection
|
||||
logger = [TensorBoardLogger(save_dir=tmpdir)]
|
||||
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=logger, limit_train_batches=5, max_epochs=1)
|
||||
|
||||
assert isinstance(trainer.logger, LoggerCollection)
|
||||
trainer.fit(model)
|
||||
assert look_for_trace(tmpdir)
|
||||
|
@ -552,3 +549,25 @@ def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_confi
|
|||
getattr(trainer, trainer_fn)(model)
|
||||
assert trainer.profiler._schedule is None
|
||||
warning_cache.clear()
|
||||
|
||||
|
||||
def test_profile_callbacks(tmpdir):
|
||||
"""Checks if profiling callbacks works correctly, specifically when there are two of the same callback type."""
|
||||
|
||||
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", record_functions=set("on_train_end"))
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=1,
|
||||
profiler=pytorch_profiler,
|
||||
callbacks=[EarlyStopping("val_loss"), EarlyStopping("train_loss")],
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert sum(
|
||||
e.name == "[Callback]EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}.on_validation_start"
|
||||
for e in pytorch_profiler.function_events
|
||||
)
|
||||
assert sum(
|
||||
e.name == "[Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
|
||||
for e in pytorch_profiler.function_events
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue