lightning/tests/test_profiler.py

189 lines
6.0 KiB
Python
Raw Normal View History

2020-12-15 17:59:13 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from pathlib import Path
2020-03-12 16:41:37 +00:00
import numpy as np
import pytest
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
# Upper bound, in seconds, for the time an empty ``profile()`` block may take,
# i.e. how much overhead the profiler itself is allowed to introduce.
PROFILER_OVERHEAD_MAX_TOLERANCE = 5e-4
def _get_python_cprofile_total_duration(profile):
return sum([x.inlinetime for x in profile.getstats()])
def _sleep_generator(durations):
"""
the profile_iterable method needs an iterable in which we can ensure that we're
properly timing how long it takes to call __next__
"""
for duration in durations:
time.sleep(duration)
yield duration
@pytest.fixture
def simple_profiler():
    """Provide a new ``SimpleProfiler`` instance for each test."""
    return SimpleProfiler()
@pytest.fixture
def advanced_profiler(tmpdir):
    """Provide an ``AdvancedProfiler`` whose output file lives in the test's tmpdir."""
    report_path = os.path.join(tmpdir, "profiler.txt")
    return AdvancedProfiler(output_filename=report_path)
@pytest.mark.parametrize("action, expected", [
    ("a", [3, 1]),
    ("b", [2]),
    ("c", [1]),
])
def test_simple_profiler_durations(simple_profiler, action, expected):
    """Check that each profiled block is recorded with roughly the time slept in it."""
    for seconds in expected:
        with simple_profiler.profile(action):
            time.sleep(seconds)

    recorded = simple_profiler.recorded_durations[action]
    # time.sleep() precision differs between environments, hence the loose rtol;
    # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
    np.testing.assert_allclose(recorded, expected, rtol=0.2)
@pytest.mark.parametrize("action, expected", [
    ("a", [3, 1]),
    ("b", [2]),
    ("c", [1]),
])
def test_simple_profiler_iterable_durations(simple_profiler, action, expected):
    """Check that ``profile_iterable`` records roughly the time each item took to produce."""
    source = _sleep_generator(expected)
    for _item in simple_profiler.profile_iterable(source, action):
        pass

    # The last recorded entry belongs to the __next__ call that raised
    # StopIteration, so only the entries before it correspond to real items.
    per_item = simple_profiler.recorded_durations[action][:-1]
    np.testing.assert_allclose(per_item, expected, rtol=0.2)
def test_simple_profiler_overhead(simple_profiler, n_iter=5):
    """Ensure that the profiler doesn't introduce too much overhead during training.

    Runs ``n_iter`` empty ``profile("no-op")`` blocks and requires that every one
    of them was recorded and took less than ``PROFILER_OVERHEAD_MAX_TOLERANCE``.
    """
    for _ in range(n_iter):
        with simple_profiler.profile("no-op"):
            pass

    durations = np.array(simple_profiler.recorded_durations["no-op"])
    # Guard against a vacuous pass: all() over an empty array is True, so a
    # profiler that stopped recording durations would otherwise satisfy the check.
    assert len(durations) == n_iter
    assert np.all(durations < PROFILER_OVERHEAD_MAX_TOLERANCE)
def test_simple_profiler_describe(caplog, simple_profiler):
    """Check that ``describe()`` runs without error and logs the report header."""
    simple_profiler.describe()
    logged = caplog.text
    assert "Profiler Report" in logged
def test_simple_profiler_value_errors(simple_profiler):
    """Check that mismatched start/stop calls raise ``ValueError``."""
    name = "test"

    # stopping an action that was never started
    with pytest.raises(ValueError):
        simple_profiler.stop(name)

    simple_profiler.start(name)
    # starting an action that is already running
    with pytest.raises(ValueError):
        simple_profiler.start(name)

    # the action that was started correctly can still be stopped
    simple_profiler.stop(name)
@pytest.mark.parametrize("action, expected", [
    ("a", [3, 1]),
    ("b", [2]),
    ("c", [1]),
])
def test_advanced_profiler_durations(advanced_profiler, action, expected):
    """Check that an action's cProfile stats add up to roughly the total time slept."""
    for seconds in expected:
        with advanced_profiler.profile(action):
            time.sleep(seconds)

    # time.sleep() precision differs between environments, hence the loose rtol;
    # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
    action_stats = advanced_profiler.profiled_actions[action]
    recorded_total = _get_python_cprofile_total_duration(action_stats)
    np.testing.assert_allclose(recorded_total, np.sum(expected), rtol=0.2)
@pytest.mark.parametrize("action, expected", [
    ("a", [3, 1]),
    ("b", [2]),
    ("c", [1]),
])
def test_advanced_profiler_iterable_durations(advanced_profiler, action, expected):
    """Check that profiling an iterable captures roughly the total time spent producing items."""
    source = _sleep_generator(expected)
    for _item in advanced_profiler.profile_iterable(source, action):
        pass

    action_stats = advanced_profiler.profiled_actions[action]
    recorded_total = _get_python_cprofile_total_duration(action_stats)
    np.testing.assert_allclose(recorded_total, np.sum(expected), rtol=0.2)
def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
    """Check that an empty profiled block costs, on average, less than the overhead tolerance."""
    for _ in range(n_iter):
        with advanced_profiler.profile("no-op"):
            pass

    noop_stats = advanced_profiler.profiled_actions["no-op"]
    # the cProfile total covers all iterations, so average it before comparing
    mean_duration = _get_python_cprofile_total_duration(noop_stats) / n_iter
    assert mean_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
def test_advanced_profiler_describe(tmpdir, advanced_profiler):
    """Check that ``describe()`` succeeds and leaves a non-empty report in ``output_fname``."""
    # there must be at least one recorded event for the report to contain anything
    with advanced_profiler.profile("test"):
        pass

    # describe() logs the summary to stdout and also writes it to the output file
    advanced_profiler.describe()
    report = Path(advanced_profiler.output_fname).read_text()
    assert report
def test_advanced_profiler_value_errors(advanced_profiler):
    """Check that stopping a never-started action raises ``ValueError``."""
    name = "test"

    with pytest.raises(ValueError):
        advanced_profiler.stop(name)

    # a normal start/stop pair must go through without error
    advanced_profiler.start(name)
    advanced_profiler.stop(name)