fix test for profiler (#800)
* fix test for profiler * use allclose * user relative tol
This commit is contained in:
parent
5130841bef
commit
fc0ad03008
|
@ -1,7 +1,9 @@
|
|||
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
||||
|
||||
|
||||
def test_simple_profiler():
|
||||
p = Profiler()
|
||||
|
@ -19,13 +21,14 @@ def test_simple_profiler():
|
|||
time.sleep(1)
|
||||
|
||||
# different environments have different precision when it comes to time.sleep()
|
||||
np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1)
|
||||
np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1)
|
||||
np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1)
|
||||
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
|
||||
np.testing.assert_allclose(p.recorded_durations["a"], [3, 1], rtol=0.2)
|
||||
np.testing.assert_allclose(p.recorded_durations["b"], [2], rtol=0.2)
|
||||
np.testing.assert_allclose(p.recorded_durations["c"], [1], rtol=0.2)
|
||||
|
||||
|
||||
def test_advanced_profiler():
|
||||
def get_duration(profile):
|
||||
def _get_duration(profile):
|
||||
return sum([x.totaltime for x in profile.getstats()])
|
||||
|
||||
p = AdvancedProfiler()
|
||||
|
@ -42,9 +45,11 @@ def test_advanced_profiler():
|
|||
with p.profile("c"):
|
||||
time.sleep(1)
|
||||
|
||||
a_duration = get_duration(p.profiled_actions["a"])
|
||||
np.testing.assert_almost_equal(a_duration, [4], decimal=1)
|
||||
b_duration = get_duration(p.profiled_actions["b"])
|
||||
np.testing.assert_almost_equal(b_duration, [2], decimal=1)
|
||||
c_duration = get_duration(p.profiled_actions["c"])
|
||||
np.testing.assert_almost_equal(c_duration, [1], decimal=1)
|
||||
# different environments have different precision when it comes to time.sleep()
|
||||
# see: https://github.com/PyTorchLightning/pytorch-lightning/issues/796
|
||||
a_duration = _get_duration(p.profiled_actions["a"])
|
||||
np.testing.assert_allclose(a_duration, [4], rtol=0.2)
|
||||
b_duration = _get_duration(p.profiled_actions["b"])
|
||||
np.testing.assert_allclose(b_duration, [2], rtol=0.2)
|
||||
c_duration = _get_duration(p.profiled_actions["c"])
|
||||
np.testing.assert_allclose(c_duration, [1], rtol=0.2)
|
||||
|
|
Loading…
Reference in New Issue