# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import platform
import time
from copy import deepcopy
from unittest.mock import patch
import numpy as np
import pytest
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profilers.pytorch import RegisterRecordFunction, warning_cache
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE
from tests_pytorch.helpers.runif import RunIf
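
# Maximum acceptable bookkeeping overhead, in seconds, that a profiler may add to a single profiled call.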
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005


def _get_python_cprofile_total_duration(profile):
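    """Sum the exclusive (inline) time over all entries of a cProfile profile."""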
return sum(x.inlinetime for x in profile.getstats())


def _sleep_generator(durations):
    """The profile_iterable method needs an iterable in which we can ensure that we're properly timing how long
    it takes to call __next__."""
for duration in durations:
time.sleep(duration)
yield duration


@pytest.fixture
def simple_profiler():
return SimpleProfiler()


@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_simple_profiler_durations(simple_profiler, action: str, expected: list):
"""Ensure the reported durations are reasonably accurate."""
for duration in expected:
with simple_profiler.profile(action):
time.sleep(duration)
# different environments have different precision when it comes to time.sleep()
# see: https://github.com/Lightning-AI/lightning/issues/796
np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2)


def test_simple_profiler_overhead(simple_profiler, n_iter=5):
"""Ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with simple_profiler.profile("no-op"):
pass
durations = np.array(simple_profiler.recorded_durations["no-op"])
assert all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE)


def test_simple_profiler_value_errors(simple_profiler):
"""Ensure errors are raised where expected."""
action = "test"
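    # stopping an action that was never started should raise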
with pytest.raises(ValueError):
simple_profiler.stop(action)
simple_profiler.start(action)
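    # starting the same action a second time without stopping it should raise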
with pytest.raises(ValueError):
simple_profiler.start(action)
simple_profiler.stop(action)


def test_simple_profiler_deepcopy(tmpdir):
simple_profiler = SimpleProfiler(dirpath=tmpdir, filename="test")
simple_profiler.describe()
assert deepcopy(simple_profiler)


def test_simple_profiler_dirpath(tmpdir):
    """Ensure the profiler dirpath defaults to `trainer.log_dir` when `dirpath` is not set."""
profiler = SimpleProfiler(filename="profiler")
assert profiler.dirpath is None
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, profiler=profiler)
trainer.fit(model)
expected = tmpdir / "lightning_logs" / "version_0"
assert trainer.log_dir == expected
assert profiler.dirpath == trainer.log_dir
assert expected.join("fit-profiler.txt").exists()


def test_simple_profiler_with_nonexisting_log_dir(tmpdir):
    """Ensure the profiler dirpath defaults to `trainer.log_dir` and creates it when not present."""
nonexisting_tmpdir = tmpdir / "nonexisting"
profiler = SimpleProfiler(filename="profiler")
assert profiler.dirpath is None
model = BoringModel()
trainer = Trainer(
default_root_dir=nonexisting_tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler
)
trainer.fit(model)
expected = nonexisting_tmpdir / "lightning_logs" / "version_0"
assert expected.exists()
assert trainer.log_dir == expected
assert profiler.dirpath == trainer.log_dir
assert expected.join("fit-profiler.txt").exists()


def test_simple_profiler_with_nonexisting_dirpath(tmpdir):
    """Ensure the profiler creates the `dirpath` when it does not exist."""
nonexisting_tmpdir = tmpdir / "nonexisting"
profiler = SimpleProfiler(dirpath=nonexisting_tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler
)
trainer.fit(model)
assert nonexisting_tmpdir.exists()
assert nonexisting_tmpdir.join("fit-profiler.txt").exists()


@RunIf(skip_windows=True)
def test_simple_profiler_distributed_files(tmpdir):
"""Ensure the proper files are saved in distributed."""
profiler = SimpleProfiler(dirpath=tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=2,
strategy="ddp_spawn",
accelerator="cpu",
devices=2,
profiler=profiler,
logger=False,
)
trainer.fit(model)
trainer.validate(model)
trainer.test(model)
actual = set(os.listdir(profiler.dirpath))
expected = {f"{stage}-profiler-{rank}.txt" for stage in ("fit", "validate", "test") for rank in (0, 1)}
assert actual == expected
for f in profiler.dirpath.listdir():
assert f.read_text("utf-8")


def test_simple_profiler_logs(tmpdir, caplog, simple_profiler):
"""Ensure that the number of printed logs is correct."""
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, profiler=simple_profiler, logger=False)
with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler"):
trainer.fit(model)
trainer.test(model)
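    # one "Profiler Report" is logged per trainer call: one for fit, one for test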
assert caplog.text.count("Profiler Report") == 2
@pytest.mark.parametrize("extended", [True, False])
@patch("time.monotonic", return_value=70)
def test_simple_profiler_summary(tmpdir, extended):
"""Test the summary of `SimpleProfiler`."""
profiler = SimpleProfiler(extended=extended)
profiler.start_time = 63.0
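    # with time.monotonic patched to 70 and start_time at 63, the report's total time is exactly 7.0s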
hooks = [
"on_train_start",
"on_train_end",
"on_train_epoch_start",
"on_train_epoch_end",
"on_before_batch_transfer",
"on_fit_start",
]
sometime = 0.773434
sep = os.linesep
max_action_len = len("on_before_batch_transfer")
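    # each hook records a single call, with durations staggered by 1s (0.773434, 1.773434, ...)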
for i, hook in enumerate(hooks):
with profiler.profile(hook):
pass
profiler.recorded_durations[hook] = [sometime + i]
if extended:
header_string = (
f"{sep}| {'Action':<{max_action_len}s}\t| {'Mean duration (s)':<15}\t| {'Num calls':<15}\t|"
f" {'Total time (s)':<15}\t| {'Percentage %':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
f"{sep}| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |" # noqa: E501
f"{sep_lines}"
f"{sep}| Total | - | 6 | 7.0 | 100 % |" # noqa: E501
f"{sep_lines}"
f"{sep}| on_fit_start | 5.7734 | 1 | 5.7734 | 82.478 |" # noqa: E501
f"{sep}| on_before_batch_transfer | 4.7734 | 1 | 4.7734 | 68.192 |" # noqa: E501
f"{sep}| on_train_epoch_end | 3.7734 | 1 | 3.7734 | 53.906 |" # noqa: E501
f"{sep}| on_train_epoch_start | 2.7734 | 1 | 2.7734 | 39.62 |" # noqa: E501
f"{sep}| on_train_end | 1.7734 | 1 | 1.7734 | 25.335 |" # noqa: E501
f"{sep}| on_train_start | 0.77343 | 1 | 0.77343 | 11.049 |" # noqa: E501
f"{sep_lines}{sep}"
)
else:
header_string = (
f"{sep}| {'Action':<{max_action_len}s}\t| {'Mean duration (s)':<15}\t| {'Total time (s)':<15}\t|"
)
output_string_len = len(header_string.expandtabs())
sep_lines = f"{sep}{'-'* output_string_len}"
expected_text = (
f"Profiler Report{sep}"
f"{sep_lines}"
f"{sep}| Action | Mean duration (s) | Total time (s) |"
f"{sep_lines}"
f"{sep}| on_fit_start | 5.7734 | 5.7734 |"
f"{sep}| on_before_batch_transfer | 4.7734 | 4.7734 |"
f"{sep}| on_train_epoch_end | 3.7734 | 3.7734 |"
f"{sep}| on_train_epoch_start | 2.7734 | 2.7734 |"
f"{sep}| on_train_end | 1.7734 | 1.7734 |"
f"{sep}| on_train_start | 0.77343 | 0.77343 |"
f"{sep_lines}{sep}"
)
summary = profiler.summary().expandtabs()
assert expected_text == summary


@pytest.fixture
def advanced_profiler(tmpdir):
return AdvancedProfiler(dirpath=tmpdir, filename="profiler")


@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
def test_advanced_profiler_durations(advanced_profiler, action: str, expected: list):
for duration in expected:
with advanced_profiler.profile(action):
time.sleep(duration)
# different environments have different precision when it comes to time.sleep()
# see: https://github.com/Lightning-AI/lightning/issues/796
recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action])
expected_total_duration = np.sum(expected)
np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)


@pytest.mark.flaky(reruns=3)
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
"""ensure that the profiler doesn't introduce too much overhead during training."""
for _ in range(n_iter):
with advanced_profiler.profile("no-op"):
pass
action_profile = advanced_profiler.profiled_actions["no-op"]
total_duration = _get_python_cprofile_total_duration(action_profile)
average_duration = total_duration / n_iter
assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE


def test_advanced_profiler_describe(tmpdir, advanced_profiler):
    """Ensure the profiler won't fail when reporting the summary."""
# record at least one event
with advanced_profiler.profile("test"):
pass
# log to stdout and print to file
advanced_profiler.describe()
path = advanced_profiler.dirpath / f"{advanced_profiler.filename}.txt"
data = path.read_text("utf-8")
assert len(data) > 0


def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"
with pytest.raises(ValueError):
advanced_profiler.stop(action)
advanced_profiler.start(action)
advanced_profiler.stop(action)


def test_advanced_profiler_deepcopy(advanced_profiler):
advanced_profiler.describe()
assert deepcopy(advanced_profiler)


@pytest.fixture
def pytorch_profiler(tmpdir):
return PyTorchProfiler(dirpath=tmpdir, filename="profiler")


@pytest.mark.xfail(raises=AssertionError, reason="TODO: Support after 1.11 profiler added")
def test_pytorch_profiler_describe(pytorch_profiler):
"""Ensure the profiler won't fail when reporting the summary."""
with pytorch_profiler.profile("on_test_start"):
torch.tensor(0)
# log to stdout and print to file
pytorch_profiler.describe()
path = pytorch_profiler.dirpath / f"{pytorch_profiler.filename}.txt"
data = path.read_text("utf-8")
assert len(data) > 0


def test_advanced_profiler_cprofile_deepcopy(tmpdir):
"""Checks for pickle issue reported in #6522."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
profiler="advanced",
callbacks=StochasticWeightAveraging(swa_lrs=1e-2),
)
trainer.fit(model)


@RunIf(min_cuda_gpus=2, standalone=True)
def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
"""Ensure that the profiler can be given to the training and default step are properly recorded."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=5,
limit_val_batches=5,
profiler=pytorch_profiler,
strategy="ddp",
accelerator="gpu",
devices=2,
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.fit(model)
expected = {"[pl][profile][Strategy]DDPStrategy.validation_step"}
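    # when Kineto is available, training_step and backward land in the exported traces
    # (checked below) rather than in `function_events`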
if not _KINETO_AVAILABLE:
expected |= {
"[pl][profile][Strategy]DDPStrategy.training_step",
"[pl][profile][Strategy]DDPStrategy.backward",
}
for name in expected:
assert sum(e.name == name for e in pytorch_profiler.function_events), name
files = set(os.listdir(pytorch_profiler.dirpath))
expected = f"fit-profiler-{trainer.local_rank}.txt"
assert expected in files
path = pytorch_profiler.dirpath / expected
assert path.read_text("utf-8")
if _KINETO_AVAILABLE:
files = os.listdir(pytorch_profiler.dirpath)
files = [file for file in files if file.endswith(".json")]
assert len(files) == 2, files
local_rank = trainer.local_rank
assert any(f"{local_rank}-[Strategy]DDPStrategy.training_step" in f for f in files)
assert any(f"{local_rank}-[Strategy]DDPStrategy.validation_step" in f for f in files)
@pytest.mark.parametrize("fast_dev_run", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("boring_model_cls", [ManualOptimBoringModel, BoringModel])
def test_pytorch_profiler_trainer_fit(fast_dev_run, boring_model_cls, tmpdir):
"""Ensure that the profiler can be given to the trainer and test step are properly recorded."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile")
model = boring_model_cls()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run, profiler=pytorch_profiler)
trainer.fit(model)
assert sum(
e.name == "[pl][profile][Strategy]SingleDeviceStrategy.validation_step"
for e in pytorch_profiler.function_events
)
path = pytorch_profiler.dirpath / f"fit-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"fit-{pytorch_profiler.filename}" in f for f in files)
@pytest.mark.parametrize("fn, step_name", [("test", "test"), ("validate", "validation"), ("predict", "predict")])
@pytest.mark.parametrize("boring_model_cls", [BoringModel, ManualOptimBoringModel])
def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir):
"""Ensure that the profiler can be given to the trainer and test step are properly recorded."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profile", schedule=None)
model = boring_model_cls()
model.predict_dataloader = model.train_dataloader
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=2, profiler=pytorch_profiler)
getattr(trainer, fn)(model)
assert sum(e.name.endswith(f"{step_name}_step") for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"{fn}-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
if _KINETO_AVAILABLE:
files = sorted(file for file in os.listdir(tmpdir) if file.endswith(".json"))
assert any(f"{fn}-{pytorch_profiler.filename}" in f for f in files)


def test_pytorch_profiler_nested(tmpdir):
"""Ensure that the profiler handles nested context."""
pytorch_profiler = PyTorchProfiler(use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None)
with pytorch_profiler.profile("a"):
a = torch.ones(42)
with pytorch_profiler.profile("b"):
b = torch.zeros(42)
with pytorch_profiler.profile("c"):
_ = a + b
pytorch_profiler.describe()
events_name = {e.name for e in pytorch_profiler.function_events}
names = {"[pl][profile]a", "[pl][profile]b", "[pl][profile]c"}
ops = {"add", "empty", "fill_", "ones", "zero_", "zeros"}
ops = {"aten::" + op for op in ops}
expected = names.union(ops)
assert events_name == expected, (events_name, torch.__version__, platform.system())


def test_pytorch_profiler_multiple_loggers(tmpdir):
    """Tests whether the PyTorch profiler is able to write its trace locally when the Trainer is configured with
    multiple loggers.

    See issue #8157.
    """

    def look_for_trace(trace_dir):
"""Determines if a directory contains a PyTorch trace."""
return any("trace.json" in filename for filename in os.listdir(trace_dir))

    model = BoringModel()
loggers = [TensorBoardLogger(save_dir=tmpdir), CSVLogger(tmpdir)]
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", logger=loggers, limit_train_batches=5, max_epochs=1)
assert len(trainer.loggers) == 2
trainer.fit(model)
assert look_for_trace(tmpdir / "lightning_logs" / "version_0")


@RunIf(min_cuda_gpus=1, standalone=True)
def test_pytorch_profiler_nested_emit_nvtx():
"""This test check emit_nvtx is correctly supported."""
profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True)
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
profiler=profiler,
accelerator="gpu",
devices=1,
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.fit(model)


def test_register_record_function(tmpdir):
use_cuda = torch.cuda.is_available()
pytorch_profiler = PyTorchProfiler(
export_to_chrome=False,
use_cuda=use_cuda,
dirpath=tmpdir,
filename="profiler",
schedule=None,
on_trace_ready=None,
)

    class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), torch.nn.Linear(1, 1))

    model = TestModel()
input = torch.rand((1, 1))
if use_cuda:
model = model.cuda()
input = input.cuda()
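    # RegisterRecordFunction wraps each submodule's forward so that every module shows up
    # as a "[pl][module]..." event in the trace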
with pytorch_profiler.profile("a"):
with RegisterRecordFunction(model):
model(input)
pytorch_profiler.describe()
event_names = [e.name for e in pytorch_profiler.function_events]
assert "[pl][module]torch.nn.modules.container.Sequential: layer" in event_names
assert "[pl][module]torch.nn.modules.linear.Linear: layer.0" in event_names
assert "[pl][module]torch.nn.modules.activation.ReLU: layer.1" in event_names
assert "[pl][module]torch.nn.modules.linear.Linear: layer.2" in event_names
@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
def test_profiler_teardown(tmpdir, cls):
"""This test checks if profiler teardown method is called when trainer is exiting."""

    class TestCallback(Callback):
def on_fit_end(self, trainer, *args, **kwargs) -> None:
# describe sets it to None
assert trainer.profiler._output_file is None

    profiler = cls(dirpath=tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1, profiler=profiler, callbacks=[TestCallback()])
trainer.fit(model)
assert profiler._output_file is None


def test_pytorch_profiler_deepcopy(tmpdir):
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler", schedule=None)
pytorch_profiler.start("on_train_start")
torch.tensor(1)
pytorch_profiler.describe()
assert deepcopy(pytorch_profiler)


@pytest.mark.parametrize(
["profiler", "expected"],
[
(None, PassThroughProfiler),
(SimpleProfiler(), SimpleProfiler),
(AdvancedProfiler(), AdvancedProfiler),
("simple", SimpleProfiler),
("Simple", SimpleProfiler),
("advanced", AdvancedProfiler),
("pytorch", PyTorchProfiler),
],
)
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {"profiler": profiler} if profiler is not None else {}
trainer = Trainer(**kwargs)
assert isinstance(trainer.profiler, expected)


def test_trainer_profiler_incorrect_str_arg():
with pytest.raises(
MisconfigurationException,
match=r"When passing string value for the `profiler` parameter of `Trainer`, it can only be one of.*",
):
Trainer(profiler="unknown_profiler")


@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto")
@pytest.mark.parametrize(
["trainer_config", "trainer_fn"],
[
({"limit_train_batches": 4, "limit_val_batches": 7}, "fit"),
({"limit_train_batches": 7, "limit_val_batches": 4, "num_sanity_val_steps": 0}, "fit"),
        ({"limit_train_batches": 7, "limit_val_batches": 2}, "fit"),
({"limit_val_batches": 4}, "validate"),
({"limit_test_batches": 4}, "test"),
({"limit_predict_batches": 4}, "predict"),
],
)
def test_pytorch_profiler_raises_warning_for_limited_steps(tmpdir, trainer_config, trainer_fn):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, profiler="pytorch", max_epochs=1, **trainer_config)
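    # the warning goes through a shared warning cache, so clear it to make sure it fires in this test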
warning_cache.clear()
with pytest.warns(UserWarning, match="not enough steps to properly record traces"):
getattr(trainer, trainer_fn)(model)
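    # with too few steps to satisfy the default schedule, the profiler falls back to no schedule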
assert trainer.profiler._schedule is None
warning_cache.clear()


def test_profile_callbacks(tmpdir):
    """Checks that profiling of callbacks works correctly, specifically when there are two of the same callback
    type."""
pytorch_profiler = PyTorchProfiler(dirpath=tmpdir, filename="profiler")
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=1,
profiler=pytorch_profiler,
callbacks=[EarlyStopping("val_loss"), EarlyStopping("train_loss")],
)
trainer.fit(model)
assert sum(
e.name == "[pl][profile][Callback]EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}.on_validation_start"
for e in pytorch_profiler.function_events
)
assert sum(
e.name == "[pl][profile][Callback]EarlyStopping{'monitor': 'train_loss', 'mode': 'min'}.on_validation_start"
for e in pytorch_profiler.function_events
)