From 4c2026bf9af34f56f6f47dc8e919906c9be575a8 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Date: Tue, 24 Mar 2020 09:15:16 -0400 Subject: [PATCH] increase profiler test coverage (#1208) * increase profiler test coverage * fix line length * tests for valueerror assertions --- pytorch_lightning/profiler/profiler.py | 2 +- tests/test_profiler.py | 93 ++++++++++++++++++++++++-- 2 files changed, 87 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/profiler/profiler.py b/pytorch_lightning/profiler/profiler.py index 00f375d596..0b04f7509a 100644 --- a/pytorch_lightning/profiler/profiler.py +++ b/pytorch_lightning/profiler/profiler.py @@ -96,7 +96,7 @@ class Profiler(BaseProfiler): def stop(self, action_name): end_time = time.monotonic() if action_name not in self.current_actions: - raise ValueError( # pragma: no-cover + raise ValueError( f"Attempting to stop recording an action ({action_name}) which was never started." ) start_time = self.current_actions.pop(action_name) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index d2bc1ebf40..43fa72df8a 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,13 +1,28 @@ +import tempfile import time +from pathlib import Path import numpy as np import pytest - -from pytorch_lightning.profiler import Profiler, AdvancedProfiler +from pytorch_lightning.profiler import AdvancedProfiler, Profiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 +def _get_python_cprofile_total_duration(profile): + return sum([x.inlinetime for x in profile.getstats()]) + + +def _sleep_generator(durations): + """ + the profile_iterable method needs an iterable in which we can ensure that we're + properly timing how long it takes to call __next__ + """ + for duration in durations: + time.sleep(duration) + yield duration + + @pytest.fixture def simple_profiler(): profiler = Profiler() @@ -35,6 +50,20 @@ def test_simple_profiler_durations(simple_profiler, action, expected): ) +@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])]) +def test_simple_profiler_iterable_durations(simple_profiler, action, expected): + """Ensure the reported durations are reasonably accurate.""" + iterable = _sleep_generator(expected) + + for duration in simple_profiler.profile_iterable(iterable, action): + pass + + # we exclude the last item in the recorded durations since that's when StopIteration is raised + np.testing.assert_allclose( + simple_profiler.recorded_durations[action][:-1], expected, rtol=0.2 + ) + + def test_simple_profiler_overhead(simple_profiler, n_iter=5): """Ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): @@ -50,10 +79,23 @@ def test_simple_profiler_describe(simple_profiler): simple_profiler.describe() +def test_simple_profiler_value_errors(simple_profiler): + """Ensure errors are raised where expected.""" + + action = "test" + with pytest.raises(ValueError): + simple_profiler.stop(action) + + simple_profiler.start(action) + + with pytest.raises(ValueError): + simple_profiler.start(action) + + simple_profiler.stop(action) + + @pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])]) def test_advanced_profiler_durations(advanced_profiler, action, expected): - def _get_total_duration(profile): - return sum([x.totaltime for x in profile.getstats()]) for duration in expected: with advanced_profiler.profile(action): @@ -61,7 +103,24 @@ def test_advanced_profiler_durations(advanced_profiler, action, expected): # different environments have different precision when it comes to time.sleep() # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 - recored_total_duration = _get_total_duration( + recored_total_duration = _get_python_cprofile_total_duration( + advanced_profiler.profiled_actions[action] + ) + expected_total_duration = np.sum(expected) + np.testing.assert_allclose( + recored_total_duration, expected_total_duration, rtol=0.2 + ) + + +@pytest.mark.parametrize("action,expected", [("a", [3, 1]), ("b", [2]), ("c", [1])]) +def test_advanced_profiler_iterable_durations(advanced_profiler, action, expected): + """Ensure the reported durations are reasonably accurate.""" + iterable = _sleep_generator(expected) + + for duration in advanced_profiler.profile_iterable(iterable, action): + pass + + recored_total_duration = _get_python_cprofile_total_duration( advanced_profiler.profiled_actions[action] ) expected_total_duration = np.sum(expected) @@ -79,13 +138,33 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): pass action_profile = advanced_profiler.profiled_actions["no-op"] - total_duration = sum([x.totaltime for x in action_profile.getstats()]) + total_duration = _get_python_cprofile_total_duration(action_profile) average_duration = total_duration / n_iter assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE -def test_advanced_profiler_describe(advanced_profiler): +def test_advanced_profiler_describe(tmpdir, advanced_profiler): """ ensure the profiler won't fail when reporting the summary """ + # record at least one event + with advanced_profiler.profile("test"): + pass + # log to stdout advanced_profiler.describe() + # print to file + advanced_profiler.output_filename = Path(tmpdir, "profiler.txt") + advanced_profiler.describe() + data = Path(advanced_profiler.output_filename).read_text() + assert len(data) > 0 + + +def test_advanced_profiler_value_errors(advanced_profiler): + """Ensure errors are raised where expected.""" + + action = "test" + with pytest.raises(ValueError): + advanced_profiler.stop(action) + + advanced_profiler.start(action) + advanced_profiler.stop(action)