new feature for profiling training runs (#782)
* initial implementation * formatting, pass through profiler, docstring * call profiler during training * add initial tests * report stats when training is done * fix formatting * error handling, bugfix in passthroughprofiler * finish documenting profiler arg in Trainer * relax required precision for profiling tests * option to dump cProfiler results to text file * use logging, format with black * include profiler in docs * improved logging and better docs * appease the linter * better summaries, wrapper for iterables * fix typo * allow profiler=True creation * more documentation * add tests for advanced profiler * Update trainer.py * make profilers accessible in pl.utilities * reorg profiler files * change import for profiler tests Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
57074b3268
commit
1cf430f7bc
|
@ -13,9 +13,15 @@ gradient clipping
|
|||
modifying training via hooks
|
||||
=============================
|
||||
|
||||
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
pl_examples
|
||||
pl_examples
|
||||
|
||||
|
||||
profiling a training run
|
||||
========================
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
profiler
|
|
@ -0,0 +1,10 @@
|
|||
.. role:: hidden
|
||||
:class: hidden-section
|
||||
|
||||
|
||||
Profiling performance during training
|
||||
===========
|
||||
.. automodule:: pytorch_lightning.profiler
|
||||
:exclude-members:
|
||||
_abc_impl,
|
||||
summarize,
|
|
@ -0,0 +1,112 @@
|
|||
"""
|
||||
Profiling your training run can help you understand if there are any bottlenecks in your code.
|
||||
|
||||
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
|
||||
|
||||
- on_epoch_start
|
||||
- on_epoch_end
|
||||
- on_batch_start
|
||||
- tbptt_split_batch
|
||||
- model_forward
|
||||
- model_backward
|
||||
- on_after_backward
|
||||
- optimizer_step
|
||||
- on_batch_end
|
||||
- training_end
|
||||
- on_training_end
|
||||
|
||||
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
|
||||
your `Trainer` object.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
trainer = Trainer(..., profiler=True)
|
||||
|
||||
The profiler's results will be printed at the completion of a training `fit()`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Profiler Report
|
||||
|
||||
Action | Mean duration (s) | Total time (s)
|
||||
-----------------------------------------------------------------
|
||||
on_epoch_start | 5.993e-06 | 5.993e-06
|
||||
get_train_batch | 0.0087412 | 16.398
|
||||
on_batch_start | 5.0865e-06 | 0.0095372
|
||||
model_forward | 0.0017818 | 3.3408
|
||||
model_backward | 0.0018283 | 3.4282
|
||||
on_after_backward | 4.2862e-06 | 0.0080366
|
||||
optimizer_step | 0.0011072 | 2.0759
|
||||
on_batch_end | 4.5202e-06 | 0.0084753
|
||||
on_epoch_end | 3.919e-06 | 3.919e-06
|
||||
on_train_end | 5.449e-06 | 5.449e-06
|
||||
|
||||
|
||||
If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
|
||||
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
|
||||
|
||||
.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
profiler = AdvancedProfiler()
|
||||
trainer = Trainer(..., profiler=profiler)
|
||||
|
||||
The profiler's results will be printed at the completion of a training `fit()`. This profiler
|
||||
report can be quite long, so you can also specify an `output_filename` to save the report instead
|
||||
of logging it to the output in your terminal. The output below shows the profiling for the action
|
||||
`get_train_batch`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Profiler Report
|
||||
|
||||
Profile stats for: get_train_batch
|
||||
4869394 function calls (4863767 primitive calls) in 18.893 seconds
|
||||
Ordered by: cumulative time
|
||||
List reduced from 76 to 10 due to restriction <10>
|
||||
ncalls tottime percall cumtime percall filename:lineno(function)
|
||||
3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next}
|
||||
1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__)
|
||||
1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data)
|
||||
1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch)
|
||||
1875 0.084 0.000 18.290 0.010 fetch.py:44(<listcomp>)
|
||||
60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__)
|
||||
60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__)
|
||||
60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__)
|
||||
60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor)
|
||||
60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__)
|
||||
|
||||
You can also reference this profiler in your LightningModule to profile specific actions of interest.
|
||||
If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler`
|
||||
which will allow you to skip profiling without having to make any code changes. Each profiler has a
|
||||
method `profile()` which returns a context handler. Simply pass in the name of your action that you want
|
||||
to track and the profiler will record performance for code executed within this context.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
|
||||
|
||||
class MyModel(LightningModule):
|
||||
def __init__(self, hparams, profiler=None):
|
||||
self.hparams = hparams
|
||||
self.profiler = profiler or PassThroughProfiler()
|
||||
|
||||
def custom_processing_step(self, data):
|
||||
with profiler.profile('my_custom_action'):
|
||||
# custom processing step
|
||||
return data
|
||||
|
||||
profiler = Profiler()
|
||||
model = MyModel(hparams, profiler)
|
||||
trainer = Trainer(profiler=profiler, max_epochs=1)
|
||||
|
||||
"""
|
||||
|
||||
from .profiler import Profiler, AdvancedProfiler, PassThroughProfiler
|
||||
|
||||
__all__ = [
|
||||
'Profiler',
|
||||
'AdvancedProfiler',
|
||||
'PassThroughProfiler',
|
||||
]
|
|
@ -0,0 +1,181 @@
|
|||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
import time
|
||||
import numpy as np
|
||||
import cProfile
|
||||
import pstats
|
||||
import io
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseProfiler(ABC):
|
||||
"""
|
||||
If you wish to write a custom profiler, you should inhereit from this class.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def start(self, action_name):
|
||||
"""
|
||||
Defines how to start recording an action.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(self, action_name):
|
||||
"""
|
||||
Defines how to record the duration once an action is complete.
|
||||
"""
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def profile(self, action_name):
|
||||
"""
|
||||
Yields a context manager to encapsulate the scope of a profiled action.
|
||||
|
||||
Example::
|
||||
|
||||
with self.profile('load training data'):
|
||||
# load training data code
|
||||
|
||||
The profiler will start once you've entered the context and will automatically
|
||||
stop once you exit the code block.
|
||||
"""
|
||||
try:
|
||||
self.start(action_name)
|
||||
yield action_name
|
||||
finally:
|
||||
self.stop(action_name)
|
||||
|
||||
def profile_iterable(self, iterable, action_name):
|
||||
iterator = iter(iterable)
|
||||
while True:
|
||||
try:
|
||||
self.start(action_name)
|
||||
value = next(iterator)
|
||||
self.stop(action_name)
|
||||
yield value
|
||||
except StopIteration:
|
||||
self.stop(action_name)
|
||||
break
|
||||
|
||||
def describe(self):
|
||||
"""
|
||||
Logs a profile report after the conclusion of the training run.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PassThroughProfiler(BaseProfiler):
|
||||
"""
|
||||
This class should be used when you don't want the (small) overhead of profiling.
|
||||
The Trainer uses this class by default.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def start(self, action_name):
|
||||
pass
|
||||
|
||||
def stop(self, action_name):
|
||||
pass
|
||||
|
||||
|
||||
class Profiler(BaseProfiler):
|
||||
"""
|
||||
This profiler simply records the duration of actions (in seconds) and reports
|
||||
the mean duration of each action and the total time spent over the entire training run.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.current_actions = {}
|
||||
self.recorded_durations = defaultdict(list)
|
||||
|
||||
def start(self, action_name):
|
||||
if action_name in self.current_actions:
|
||||
raise ValueError(
|
||||
f"Attempted to start {action_name} which has already started."
|
||||
)
|
||||
self.current_actions[action_name] = time.monotonic()
|
||||
|
||||
def stop(self, action_name):
|
||||
end_time = time.monotonic()
|
||||
if action_name not in self.current_actions:
|
||||
raise ValueError(
|
||||
f"Attempting to stop recording an action ({action_name}) which was never started."
|
||||
)
|
||||
start_time = self.current_actions.pop(action_name)
|
||||
duration = end_time - start_time
|
||||
self.recorded_durations[action_name].append(duration)
|
||||
|
||||
def describe(self):
|
||||
output_string = "\n\nProfiler Report\n"
|
||||
|
||||
def log_row(action, mean, total):
|
||||
return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
|
||||
|
||||
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
|
||||
output_string += f"\n{'-' * 65}"
|
||||
for action, durations in self.recorded_durations.items():
|
||||
output_string += log_row(
|
||||
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
|
||||
)
|
||||
output_string += "\n"
|
||||
logger.info(output_string)
|
||||
|
||||
|
||||
class AdvancedProfiler(BaseProfiler):
|
||||
"""
|
||||
This profiler uses Python's cProfiler to record more detailed information about
|
||||
time spent in each function call recorded during a given action. The output is quite
|
||||
verbose and you should only use this if you want very detailed reports.
|
||||
"""
|
||||
|
||||
def __init__(self, output_filename=None, line_count_restriction=1.0):
|
||||
"""
|
||||
:param output_filename (str): optionally save profile results to file instead of printing
|
||||
to std out when training is finished.
|
||||
:param line_count_restriction (int|float): this can be used to limit the number of functions
|
||||
reported for each action. either an integer (to select a count of lines),
|
||||
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
|
||||
"""
|
||||
self.profiled_actions = {}
|
||||
self.output_filename = output_filename
|
||||
self.line_count_restriction = line_count_restriction
|
||||
|
||||
def start(self, action_name):
|
||||
if action_name not in self.profiled_actions:
|
||||
self.profiled_actions[action_name] = cProfile.Profile()
|
||||
self.profiled_actions[action_name].enable()
|
||||
|
||||
def stop(self, action_name):
|
||||
pr = self.profiled_actions.get(action_name)
|
||||
if pr is None:
|
||||
raise ValueError(
|
||||
f"Attempting to stop recording an action ({action_name}) which was never started."
|
||||
)
|
||||
pr.disable()
|
||||
|
||||
def describe(self):
|
||||
self.recorded_stats = {}
|
||||
for action_name, pr in self.profiled_actions.items():
|
||||
s = io.StringIO()
|
||||
sortby = pstats.SortKey.CUMULATIVE
|
||||
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
|
||||
ps.print_stats(self.line_count_restriction)
|
||||
self.recorded_stats[action_name] = s.getvalue()
|
||||
if self.output_filename is not None:
|
||||
# save to file
|
||||
with open(self.output_filename, "w") as f:
|
||||
for action, stats in self.recorded_stats.items():
|
||||
f.write(f"Profile stats for: {action}")
|
||||
f.write(stats)
|
||||
else:
|
||||
# log to standard out
|
||||
output_string = "\nProfiler Report\n"
|
||||
for action, stats in self.recorded_stats.items():
|
||||
output_string += f"\nProfile stats for: {action}\n{stats}"
|
||||
logger.info(output_string)
|
|
@ -212,7 +212,7 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
# bookkeeping
|
||||
outputs = []
|
||||
|
||||
# run training
|
||||
# run validation
|
||||
for dataloader_idx, dataloader in enumerate(dataloaders):
|
||||
dl_outputs = []
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
|
|
|
@ -26,6 +26,8 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin
|
|||
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
|
||||
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
||||
from pytorch_lightning.utilities.debugging import MisconfigurationException
|
||||
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
|
||||
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -87,6 +89,7 @@ class Trainer(TrainerIOMixin,
|
|||
num_sanity_val_steps=5,
|
||||
truncated_bptt_steps=None,
|
||||
resume_from_checkpoint=None,
|
||||
profiler=None
|
||||
):
|
||||
r"""
|
||||
|
||||
|
@ -460,6 +463,25 @@ class Trainer(TrainerIOMixin,
|
|||
|
||||
# resume from a specific checkpoint
|
||||
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
|
||||
profiler (BaseProfiler): To profile individual steps during training and assist in
|
||||
identifying bottlenecks.
|
||||
Example::
|
||||
|
||||
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
||||
|
||||
# default used by the Trainer
|
||||
trainer = Trainer(profiler=None)
|
||||
|
||||
# to profile standard training events
|
||||
trainer = Trainer(profiler=True)
|
||||
|
||||
# equivalent to profiler=True
|
||||
profiler = Profiler()
|
||||
trainer = Trainer(profiler=profiler)
|
||||
|
||||
# advanced profiler for function-level stats
|
||||
profiler = AdvancedProfiler()
|
||||
trainer = Trainer(profiler=profiler)
|
||||
|
||||
.. warning:: Following arguments become deprecated and they will be removed in v0.8.0:
|
||||
|
||||
|
@ -564,6 +586,11 @@ class Trainer(TrainerIOMixin,
|
|||
# configure logger
|
||||
self.configure_logger(logger)
|
||||
|
||||
# configure profiler
|
||||
if profiler is True:
|
||||
profiler = Profiler()
|
||||
self.profiler = profiler or PassThroughProfiler()
|
||||
|
||||
# configure early stop callback
|
||||
# creates a default one if none passed in
|
||||
self.configure_early_stopping(early_stop_callback)
|
||||
|
@ -870,6 +897,9 @@ class Trainer(TrainerIOMixin,
|
|||
# CORE TRAINING LOOP
|
||||
self.train()
|
||||
|
||||
# summarize profile results
|
||||
self.profiler.describe()
|
||||
|
||||
def test(self, model=None):
|
||||
r"""
|
||||
|
||||
|
|
|
@ -211,6 +211,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
self.training_tqdm_dict = None
|
||||
self.get_train_dataloader = None
|
||||
self.reduce_lr_on_plateau_scheduler = None
|
||||
self.profiler = None
|
||||
|
||||
@property
|
||||
def max_nb_epochs(self):
|
||||
|
@ -357,12 +358,14 @@ class TrainerTrainLoopMixin(ABC):
|
|||
stop = should_stop and met_min_epochs
|
||||
if stop:
|
||||
self.main_progress_bar.close()
|
||||
model.on_train_end()
|
||||
with self.profiler.profile('on_train_end'):
|
||||
model.on_train_end()
|
||||
return
|
||||
|
||||
self.main_progress_bar.close()
|
||||
|
||||
model.on_train_end()
|
||||
with self.profiler.profile('on_train_end'):
|
||||
model.on_train_end()
|
||||
|
||||
if self.logger is not None:
|
||||
self.logger.finalize("success")
|
||||
|
@ -371,10 +374,13 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# before epoch hook
|
||||
if self.is_function_implemented('on_epoch_start'):
|
||||
model = self.get_model()
|
||||
model.on_epoch_start()
|
||||
with self.profiler.profile('on_epoch_start'):
|
||||
model.on_epoch_start()
|
||||
|
||||
# run epoch
|
||||
for batch_idx, batch in enumerate(self.get_train_dataloader()):
|
||||
for batch_idx, batch in self.profiler.profile_iterable(
|
||||
enumerate(self.get_train_dataloader()), "get_train_batch"
|
||||
):
|
||||
# stop epoch if we limited the number of training batches
|
||||
if batch_idx >= self.num_training_batches:
|
||||
break
|
||||
|
@ -432,7 +438,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# epoch end hook
|
||||
if self.is_function_implemented('on_epoch_end'):
|
||||
model = self.get_model()
|
||||
model.on_epoch_end()
|
||||
with self.profiler.profile('on_epoch_end'):
|
||||
model.on_epoch_end()
|
||||
|
||||
def run_training_batch(self, batch, batch_idx):
|
||||
# track grad norms
|
||||
|
@ -450,7 +457,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# hook
|
||||
if self.is_function_implemented('on_batch_start'):
|
||||
model_ref = self.get_model()
|
||||
response = model_ref.on_batch_start(batch)
|
||||
with self.profiler.profile('on_batch_start'):
|
||||
response = model_ref.on_batch_start(batch)
|
||||
|
||||
if response == -1:
|
||||
return -1, grad_norm_dic, {}
|
||||
|
@ -458,7 +466,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
splits = [batch]
|
||||
if self.truncated_bptt_steps is not None:
|
||||
model_ref = self.get_model()
|
||||
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
|
||||
with self.profiler.profile('tbptt_split_batch'):
|
||||
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
|
||||
|
||||
self.hiddens = None
|
||||
for split_idx, split_batch in enumerate(splits):
|
||||
|
@ -478,8 +487,9 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# wrap the forward step in a closure so second order methods work
|
||||
def optimizer_closure():
|
||||
# forward pass
|
||||
output = self.training_forward(
|
||||
split_batch, batch_idx, opt_idx, self.hiddens)
|
||||
with self.profiler.profile('model_forward'):
|
||||
output = self.training_forward(
|
||||
split_batch, batch_idx, opt_idx, self.hiddens)
|
||||
|
||||
closure_loss = output[0]
|
||||
progress_bar_metrics = output[1]
|
||||
|
@ -493,7 +503,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
|
||||
# backward pass
|
||||
model_ref = self.get_model()
|
||||
model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)
|
||||
with self.profiler.profile('model_backward'):
|
||||
model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)
|
||||
|
||||
# track metrics for callbacks
|
||||
all_callback_metrics.append(callback_metrics)
|
||||
|
@ -505,7 +516,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# insert after step hook
|
||||
if self.is_function_implemented('on_after_backward'):
|
||||
model_ref = self.get_model()
|
||||
model_ref.on_after_backward()
|
||||
with self.profiler.profile('on_after_backward'):
|
||||
model_ref.on_after_backward()
|
||||
|
||||
return closure_loss
|
||||
|
||||
|
@ -535,8 +547,9 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# calls .step(), .zero_grad()
|
||||
# override function to modify this behavior
|
||||
model = self.get_model()
|
||||
model.optimizer_step(self.current_epoch, batch_idx,
|
||||
optimizer, opt_idx, optimizer_closure)
|
||||
with self.profiler.profile('optimizer_step'):
|
||||
model.optimizer_step(self.current_epoch, batch_idx,
|
||||
optimizer, opt_idx, optimizer_closure)
|
||||
|
||||
# calculate running loss for display
|
||||
self.running_loss.append(self.batch_loss_value)
|
||||
|
@ -546,7 +559,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# activate batch end hook
|
||||
if self.is_function_implemented('on_batch_end'):
|
||||
model = self.get_model()
|
||||
model.on_batch_end()
|
||||
with self.profiler.profile('on_batch_end'):
|
||||
model.on_batch_end()
|
||||
|
||||
# update progress bar
|
||||
self.main_progress_bar.update(1)
|
||||
|
@ -606,7 +620,8 @@ class TrainerTrainLoopMixin(ABC):
|
|||
# allow any mode to define training_end
|
||||
if self.is_overriden('training_end'):
|
||||
model_ref = self.get_model()
|
||||
output = model_ref.training_end(output)
|
||||
with self.profiler.profile('training_end'):
|
||||
output = model_ref.training_end(output)
|
||||
|
||||
# format and reduce outputs accordingly
|
||||
output = self.process_output(output, train=True)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_simple_profiler():
|
||||
p = Profiler()
|
||||
|
||||
with p.profile("a"):
|
||||
time.sleep(3)
|
||||
|
||||
with p.profile("a"):
|
||||
time.sleep(1)
|
||||
|
||||
with p.profile("b"):
|
||||
time.sleep(2)
|
||||
|
||||
with p.profile("c"):
|
||||
time.sleep(1)
|
||||
|
||||
# different environments have different precision when it comes to time.sleep()
|
||||
np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1)
|
||||
np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1)
|
||||
np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1)
|
||||
|
||||
|
||||
def test_advanced_profiler():
|
||||
def get_duration(profile):
|
||||
return sum([x.totaltime for x in profile.getstats()])
|
||||
|
||||
p = AdvancedProfiler()
|
||||
|
||||
with p.profile("a"):
|
||||
time.sleep(3)
|
||||
|
||||
with p.profile("a"):
|
||||
time.sleep(1)
|
||||
|
||||
with p.profile("b"):
|
||||
time.sleep(2)
|
||||
|
||||
with p.profile("c"):
|
||||
time.sleep(1)
|
||||
|
||||
a_duration = get_duration(p.profiled_actions["a"])
|
||||
np.testing.assert_almost_equal(a_duration, [4], decimal=1)
|
||||
b_duration = get_duration(p.profiled_actions["b"])
|
||||
np.testing.assert_almost_equal(b_duration, [2], decimal=1)
|
||||
c_duration = get_duration(p.profiled_actions["c"])
|
||||
np.testing.assert_almost_equal(c_duration, [1], decimal=1)
|
Loading…
Reference in New Issue