increase profiler test coverage (#1208)
* increase profiler test coverage
* fix line length
* tests for valueerror assertions
parent 3be81cb54e
commit 4c2026bf9a
@@ -96,7 +96,7 @@ class Profiler(BaseProfiler):
     def stop(self, action_name):
         end_time = time.monotonic()
         if action_name not in self.current_actions:
-            raise ValueError(  # pragma: no-cover
+            raise ValueError(
                 f"Attempting to stop recording an action ({action_name}) which was never started."
             )
         start_time = self.current_actions.pop(action_name)
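
Note: dropping the "# pragma: no-cover" marker is consistent with the new
test_simple_profiler_value_errors test below, which now exercises this branch.
A minimal sketch of triggering the error by hand, assuming only the public
Profiler API shown in this diff:

    from pytorch_lightning.profiler import Profiler

    profiler = Profiler()
    try:
        # stop() without a matching start() hits the ValueError branch above
        profiler.stop("never-started")
    except ValueError as err:
        print(err)  # Attempting to stop recording an action (never-started) ...
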
@@ -1,13 +1,28 @@
+import tempfile
 import time
+from pathlib import Path
 
 import numpy as np
 import pytest
 
-from pytorch_lightning.profiler import Profiler, AdvancedProfiler
+from pytorch_lightning.profiler import AdvancedProfiler, Profiler
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
 
 
+def _get_python_cprofile_total_duration(profile):
+    return sum([x.inlinetime for x in profile.getstats()])
+
+
+def _sleep_generator(durations):
+    """
+    the profile_iterable method needs an iterable in which we can ensure that we're
+    properly timing how long it takes to call __next__
+    """
+    for duration in durations:
+        time.sleep(duration)
+        yield duration
+
+
 @pytest.fixture
 def simple_profiler():
     profiler = Profiler()
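
Note: the new module-level _get_python_cprofile_total_duration helper sums the
inlinetime of each cProfile stat entry, replacing the per-test totaltime
helpers. A standalone sketch of the same idea against the standard library's
cProfile.Profile (getstats() entries expose both fields):

    import cProfile

    def busy():
        return sum(i * i for i in range(100_000))

    profile = cProfile.Profile()
    profile.enable()
    busy()
    profile.disable()

    # inlinetime is time spent inside each function itself;
    # totaltime would also include time spent in subcalls
    total = sum(entry.inlinetime for entry in profile.getstats())
    print(f"total inline time: {total:.6f}s")
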
@@ -35,6 +50,20 @@ def test_simple_profiler_durations(simple_profiler, action, expected):
     )
 
 
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_simple_profiler_iterable_durations(simple_profiler, action, expected):
+    """Ensure the reported durations are reasonably accurate."""
+    iterable = _sleep_generator(expected)
+
+    for duration in simple_profiler.profile_iterable(iterable, action):
+        pass
+
+    # we exclude the last item in the recorded durations since that's when StopIteration is raised
+    np.testing.assert_allclose(
+        simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2
+    )
+
+
 def test_simple_profiler_overhead(simple_profiler, n_iter=5):
     """Ensure that the profiler doesn't introduce too much overhead during training."""
     for _ in range(n_iter):
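
Note: the "exclude the last item" comment follows from how profile_iterable
presumably times the iterator: every __next__ call is wrapped in start/stop,
including the final call that raises StopIteration, so the last recorded
duration measures profiler overhead rather than a sleep. A hedged sketch of
that pattern, not the library's exact code:

    def profile_iterable_sketch(profiler, iterable, action_name):
        iterator = iter(iterable)
        while True:
            try:
                profiler.start(action_name)
                value = next(iterator)
                profiler.stop(action_name)
                yield value
            except StopIteration:
                # the failed next() call is still timed, hence the [:-1] above
                profiler.stop(action_name)
                break
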
@@ -50,10 +79,23 @@ def test_simple_profiler_describe(simple_profiler):
     simple_profiler.describe()
 
 
+def test_simple_profiler_value_errors(simple_profiler):
+    """Ensure errors are raised where expected."""
+
+    action = "test"
+    with pytest.raises(ValueError):
+        simple_profiler.stop(action)
+
+    simple_profiler.start(action)
+
+    with pytest.raises(ValueError):
+        simple_profiler.start(action)
+
+    simple_profiler.stop(action)
+
+
 @pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
 def test_advanced_profiler_durations(advanced_profiler, action, expected):
-    def _get_total_duration(profile):
-        return sum([x.totaltime for x in profile.getstats()])
-
     for duration in expected:
         with advanced_profiler.profile(action):
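
Note: the double-start assertion implies Profiler.start carries a guard
mirroring the one stop() gained in the first hunk. A hedged sketch of what
that method presumably looks like (the exact error message is an assumption;
only current_actions and time.monotonic() are confirmed by this diff):

    def start(self, action_name):
        if action_name in self.current_actions:
            # starting an action twice must raise, per the test above
            raise ValueError(
                f"Attempted to start an action ({action_name}) which has already started."
            )
        self.current_actions[action_name] = time.monotonic()
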
@@ -61,7 +103,24 @@ def test_advanced_profiler_durations(advanced_profiler, action, expected):
 
     # different environments have different precision when it comes to time.sleep()
     # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
-    recored_total_duration = _get_total_duration(
+    recored_total_duration = _get_python_cprofile_total_duration(
         advanced_profiler.profiled_actions[action]
     )
     expected_total_duration = np.sum(expected)
     np.testing.assert_allclose(
         recored_total_duration, expected_total_duration, rtol=0.2
     )
 
 
+@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])])
+def test_advanced_profiler_iterable_durations(advanced_profiler, action, expected):
+    """Ensure the reported durations are reasonably accurate."""
+    iterable = _sleep_generator(expected)
+
+    for duration in advanced_profiler.profile_iterable(iterable, action):
+        pass
+
+    recored_total_duration = _get_python_cprofile_total_duration(
+        advanced_profiler.profiled_actions[action]
+    )
+    expected_total_duration = np.sum(expected)
+
@@ -79,13 +138,33 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
         pass
 
     action_profile = advanced_profiler.profiled_actions["no-op"]
-    total_duration = sum([x.totaltime for x in action_profile.getstats()])
+    total_duration = _get_python_cprofile_total_duration(action_profile)
     average_duration = total_duration / n_iter
     assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
 
 
-def test_advanced_profiler_describe(advanced_profiler):
+def test_advanced_profiler_describe(tmpdir, advanced_profiler):
     """
     ensure the profiler won't fail when reporting the summary
     """
+    # record at least one event
+    with advanced_profiler.profile("test"):
+        pass
+    # log to stdout
     advanced_profiler.describe()
+    # print to file
+    advanced_profiler.output_filename = Path(tmpdir, "profiler.txt")
+    advanced_profiler.describe()
+    data = Path(advanced_profiler.output_filename).read_text()
+    assert len(data) > 0
+
+
+def test_advanced_profiler_value_errors(advanced_profiler):
+    """Ensure errors are raised where expected."""
+
+    action = "test"
+    with pytest.raises(ValueError):
+        advanced_profiler.stop(action)
+
+    advanced_profiler.start(action)
+    advanced_profiler.stop(action)
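
Note: the describe-to-file path tested above can also be exercised directly.
A hedged usage sketch, assuming the advanced_profiler fixture wraps a default
AdvancedProfiler() (the temporary-directory handling here is illustrative, not
from the commit):

    import tempfile
    from pathlib import Path

    from pytorch_lightning.profiler import AdvancedProfiler

    profiler = AdvancedProfiler()
    with profiler.profile("forward"):
        pass  # record at least one event so describe() has data to report

    profiler.describe()  # summary to stdout

    with tempfile.TemporaryDirectory() as tmpdir:
        profiler.output_filename = Path(tmpdir, "profiler.txt")
        profiler.describe()  # summary written to the file instead
        assert Path(profiler.output_filename).read_text()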