51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
|
import time
|
|
import numpy as np
|
|
|
|
|
|
def test_simple_profiler():
    """Check the durations that ``Profiler`` records for each profiled action.

    Four blocks are profiled under the action names "a" (twice), "b" and "c",
    each sleeping for a known number of seconds. Afterwards
    ``recorded_durations`` is expected to hold one entry per profiled block,
    grouped by action name, matching the sleep times to one decimal place.
    """
    profiler = Profiler()

    # (action name, seconds slept) in execution order; "a" is profiled twice.
    sleep_plan = (("a", 3), ("a", 1), ("b", 2), ("c", 1))
    for action, seconds in sleep_plan:
        with profiler.profile(action):
            time.sleep(seconds)

    # different environments have different precision when it comes to time.sleep()
    expected_durations = {"a": [3, 1], "b": [2], "c": [1]}
    for action, expected in expected_durations.items():
        np.testing.assert_almost_equal(
            profiler.recorded_durations[action], expected, decimal=1
        )
|
|
|
|
|
|
def test_advanced_profiler():
    """Check the total time ``AdvancedProfiler`` accumulates per action.

    Four blocks are profiled under the action names "a" (twice), "b" and "c",
    each sleeping for a known number of seconds. For every action, the
    ``totaltime`` values of the stats entries stored in ``profiled_actions``
    are summed and compared with the expected total to one decimal place.
    """

    def total_time(action_profile):
        """Return the sum of ``totaltime`` over all entries of ``getstats()``."""
        return sum(entry.totaltime for entry in action_profile.getstats())

    profiler = AdvancedProfiler()

    # (action name, seconds slept) in execution order; "a" is profiled twice.
    for action, seconds in (("a", 3), ("a", 1), ("b", 2), ("c", 1)):
        with profiler.profile(action):
            time.sleep(seconds)

    # "a" slept 3s and then 1s, so its summed total is expected to be 4s.
    for action, expected_total in (("a", [4]), ("b", [2]), ("c", [1])):
        measured = total_time(profiler.profiled_actions[action])
        np.testing.assert_almost_equal(measured, expected_total, decimal=1)
|