diff --git a/tests/test_profiler.py b/tests/test_profiler.py index d6e085a55e..1b26030de8 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,7 +1,9 @@ -from pytorch_lightning.profiler import Profiler, AdvancedProfiler import time + import numpy as np +from pytorch_lightning.profiler import Profiler, AdvancedProfiler + def test_simple_profiler(): p = Profiler() @@ -19,13 +21,14 @@ def test_simple_profiler(): time.sleep(1) # different environments have different precision when it comes to time.sleep() - np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1) - np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1) - np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1) + # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 + np.testing.assert_allclose(p.recorded_durations["a"], [3, 1], rtol=0.2) + np.testing.assert_allclose(p.recorded_durations["b"], [2], rtol=0.2) + np.testing.assert_allclose(p.recorded_durations["c"], [1], rtol=0.2) def test_advanced_profiler(): - def get_duration(profile): + def _get_duration(profile): return sum([x.totaltime for x in profile.getstats()]) p = AdvancedProfiler() @@ -42,9 +45,11 @@ def test_advanced_profiler(): with p.profile("c"): time.sleep(1) - a_duration = get_duration(p.profiled_actions["a"]) - np.testing.assert_almost_equal(a_duration, [4], decimal=1) - b_duration = get_duration(p.profiled_actions["b"]) - np.testing.assert_almost_equal(b_duration, [2], decimal=1) - c_duration = get_duration(p.profiled_actions["c"]) - np.testing.assert_almost_equal(c_duration, [1], decimal=1) + # different environments have different precision when it comes to time.sleep() + # see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796 + a_duration = _get_duration(p.profiled_actions["a"]) + np.testing.assert_allclose(a_duration, [4], rtol=0.2) + b_duration = _get_duration(p.profiled_actions["b"]) + np.testing.assert_allclose(b_duration, [2], rtol=0.2) + c_duration = _get_duration(p.profiled_actions["c"]) + np.testing.assert_allclose(c_duration, [1], rtol=0.2)