new feature for profiling training runs (#782)

* initial implementation

* formatting, pass through profiler, docstring

* call profiler during training

* add initial tests

* report stats when training is done

* fix formatting

* error handling, bugfix in passthroughprofiler

* finish documenting profiler arg in Trainer

* relax required precision for profiling tests

* option to dump cProfiler results to text file

* use logging, format with black

* include profiler in docs

* improved logging and better docs

* appease the linter

* better summaries, wrapper for iterables

* fix typo

* allow profiler=True creation

* more documentation

* add tests for advanced profiler

* Update trainer.py

* make profilers accessible in pl.utilities

* reorg profiler files

* change import for profiler tests

Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
Jeremy Jordan 2020-02-06 22:01:21 -05:00 committed by GitHub
parent 57074b3268
commit 1cf430f7bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 423 additions and 19 deletions

View File

@ -13,9 +13,15 @@ gradient clipping
modifying training via hooks
=============================
.. toctree::
:maxdepth: 3
pl_examples
pl_examples
profiling a training run
========================
.. toctree::
:maxdepth: 1
profiler

10
docs/source/profiler.rst Normal file
View File

@ -0,0 +1,10 @@
.. role:: hidden
:class: hidden-section
Profiling performance during training
===========
.. automodule:: pytorch_lightning.profiler
:exclude-members:
_abc_impl,
summarize,

View File

@ -0,0 +1,112 @@
"""
Profiling your training run can help you understand if there are any bottlenecks in your code.
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
- on_epoch_start
- on_epoch_end
- on_batch_start
- tbptt_split_batch
- model_forward
- model_backward
- on_after_backward
- optimizer_step
- on_batch_end
- training_end
- on_training_end
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.
.. code-block:: python
trainer = Trainer(..., profiler=True)
The profiler's results will be printed at the completion of a training `fit()`.
.. code-block:: python
Profiler Report
Action | Mean duration (s) | Total time (s)
-----------------------------------------------------------------
on_epoch_start | 5.993e-06 | 5.993e-06
get_train_batch | 0.0087412 | 16.398
on_batch_start | 5.0865e-06 | 0.0095372
model_forward | 0.0017818 | 3.3408
model_backward | 0.0018283 | 3.4282
on_after_backward | 4.2862e-06 | 0.0080366
optimizer_step | 0.0011072 | 2.0759
on_batch_end | 4.5202e-06 | 0.0084753
on_epoch_end | 3.919e-06 | 3.919e-06
on_train_end | 5.449e-06 | 5.449e-06
If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile
.. code-block:: python
profiler = AdvancedProfiler()
trainer = Trainer(..., profiler=profiler)
The profiler's results will be printed at the completion of a training `fit()`. This profiler
report can be quite long, so you can also specify an `output_filename` to save the report instead
of logging it to the output in your terminal. The output below shows the profiling for the action
`get_train_batch`.
.. code-block:: python
Profiler Report
Profile stats for: get_train_batch
4869394 function calls (4863767 primitive calls) in 18.893 seconds
Ordered by: cumulative time
List reduced from 76 to 10 due to restriction <10>
ncalls tottime percall cumtime percall filename:lineno(function)
3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next}
1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__)
1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data)
1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch)
1875 0.084 0.000 18.290 0.010 fetch.py:44(<listcomp>)
60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__)
60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__)
60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__)
60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor)
60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__)
You can also reference this profiler in your LightningModule to profile specific actions of interest.
If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler`
which will allow you to skip profiling without having to make any code changes. Each profiler has a
method `profile()` which returns a context handler. Simply pass in the name of your action that you want
to track and the profiler will record performance for code executed within this context.
.. code-block:: python
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
class MyModel(LightningModule):
def __init__(self, hparams, profiler=None):
self.hparams = hparams
self.profiler = profiler or PassThroughProfiler()
def custom_processing_step(self, data):
with profiler.profile('my_custom_action'):
# custom processing step
return data
profiler = Profiler()
model = MyModel(hparams, profiler)
trainer = Trainer(profiler=profiler, max_epochs=1)
"""
from .profiler import Profiler, AdvancedProfiler, PassThroughProfiler
__all__ = [
'Profiler',
'AdvancedProfiler',
'PassThroughProfiler',
]

View File

@ -0,0 +1,181 @@
from contextlib import contextmanager
from collections import defaultdict
import time
import numpy as np
import cProfile
import pstats
import io
from abc import ABC, abstractmethod
import logging
logger = logging.getLogger(__name__)
class BaseProfiler(ABC):
"""
If you wish to write a custom profiler, you should inhereit from this class.
"""
@abstractmethod
def start(self, action_name):
"""
Defines how to start recording an action.
"""
pass
@abstractmethod
def stop(self, action_name):
"""
Defines how to record the duration once an action is complete.
"""
pass
@contextmanager
def profile(self, action_name):
"""
Yields a context manager to encapsulate the scope of a profiled action.
Example::
with self.profile('load training data'):
# load training data code
The profiler will start once you've entered the context and will automatically
stop once you exit the code block.
"""
try:
self.start(action_name)
yield action_name
finally:
self.stop(action_name)
def profile_iterable(self, iterable, action_name):
iterator = iter(iterable)
while True:
try:
self.start(action_name)
value = next(iterator)
self.stop(action_name)
yield value
except StopIteration:
self.stop(action_name)
break
def describe(self):
"""
Logs a profile report after the conclusion of the training run.
"""
pass
class PassThroughProfiler(BaseProfiler):
"""
This class should be used when you don't want the (small) overhead of profiling.
The Trainer uses this class by default.
"""
def __init__(self):
pass
def start(self, action_name):
pass
def stop(self, action_name):
pass
class Profiler(BaseProfiler):
"""
This profiler simply records the duration of actions (in seconds) and reports
the mean duration of each action and the total time spent over the entire training run.
"""
def __init__(self):
self.current_actions = {}
self.recorded_durations = defaultdict(list)
def start(self, action_name):
if action_name in self.current_actions:
raise ValueError(
f"Attempted to start {action_name} which has already started."
)
self.current_actions[action_name] = time.monotonic()
def stop(self, action_name):
end_time = time.monotonic()
if action_name not in self.current_actions:
raise ValueError(
f"Attempting to stop recording an action ({action_name}) which was never started."
)
start_time = self.current_actions.pop(action_name)
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)
def describe(self):
output_string = "\n\nProfiler Report\n"
def log_row(action, mean, total):
return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"\n{'-' * 65}"
for action, durations in self.recorded_durations.items():
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
)
output_string += "\n"
logger.info(output_string)
class AdvancedProfiler(BaseProfiler):
"""
This profiler uses Python's cProfiler to record more detailed information about
time spent in each function call recorded during a given action. The output is quite
verbose and you should only use this if you want very detailed reports.
"""
def __init__(self, output_filename=None, line_count_restriction=1.0):
"""
:param output_filename (str): optionally save profile results to file instead of printing
to std out when training is finished.
:param line_count_restriction (int|float): this can be used to limit the number of functions
reported for each action. either an integer (to select a count of lines),
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
"""
self.profiled_actions = {}
self.output_filename = output_filename
self.line_count_restriction = line_count_restriction
def start(self, action_name):
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
self.profiled_actions[action_name].enable()
def stop(self, action_name):
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError(
f"Attempting to stop recording an action ({action_name}) which was never started."
)
pr.disable()
def describe(self):
self.recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
s = io.StringIO()
sortby = pstats.SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
ps.print_stats(self.line_count_restriction)
self.recorded_stats[action_name] = s.getvalue()
if self.output_filename is not None:
# save to file
with open(self.output_filename, "w") as f:
for action, stats in self.recorded_stats.items():
f.write(f"Profile stats for: {action}")
f.write(stats)
else:
# log to standard out
output_string = "\nProfiler Report\n"
for action, stats in self.recorded_stats.items():
output_string += f"\nProfile stats for: {action}\n{stats}"
logger.info(output_string)

View File

@ -212,7 +212,7 @@ class TrainerEvaluationLoopMixin(ABC):
# bookkeeping
outputs = []
# run training
# run validation
for dataloader_idx, dataloader in enumerate(dataloaders):
dl_outputs = []
for batch_idx, batch in enumerate(dataloader):

View File

@ -26,6 +26,8 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
try:
from apex import amp
@ -87,6 +89,7 @@ class Trainer(TrainerIOMixin,
num_sanity_val_steps=5,
truncated_bptt_steps=None,
resume_from_checkpoint=None,
profiler=None
):
r"""
@ -460,6 +463,25 @@ class Trainer(TrainerIOMixin,
# resume from a specific checkpoint
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')
profiler (BaseProfiler): To profile individual steps during training and assist in
identifying bottlenecks.
Example::
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
# default used by the Trainer
trainer = Trainer(profiler=None)
# to profile standard training events
trainer = Trainer(profiler=True)
# equivalent to profiler=True
profiler = Profiler()
trainer = Trainer(profiler=profiler)
# advanced profiler for function-level stats
profiler = AdvancedProfiler()
trainer = Trainer(profiler=profiler)
.. warning:: Following arguments become deprecated and they will be removed in v0.8.0:
@ -564,6 +586,11 @@ class Trainer(TrainerIOMixin,
# configure logger
self.configure_logger(logger)
# configure profiler
if profiler is True:
profiler = Profiler()
self.profiler = profiler or PassThroughProfiler()
# configure early stop callback
# creates a default one if none passed in
self.configure_early_stopping(early_stop_callback)
@ -870,6 +897,9 @@ class Trainer(TrainerIOMixin,
# CORE TRAINING LOOP
self.train()
# summarize profile results
self.profiler.describe()
def test(self, model=None):
r"""

View File

@ -211,6 +211,7 @@ class TrainerTrainLoopMixin(ABC):
self.training_tqdm_dict = None
self.get_train_dataloader = None
self.reduce_lr_on_plateau_scheduler = None
self.profiler = None
@property
def max_nb_epochs(self):
@ -357,12 +358,14 @@ class TrainerTrainLoopMixin(ABC):
stop = should_stop and met_min_epochs
if stop:
self.main_progress_bar.close()
model.on_train_end()
with self.profiler.profile('on_train_end'):
model.on_train_end()
return
self.main_progress_bar.close()
model.on_train_end()
with self.profiler.profile('on_train_end'):
model.on_train_end()
if self.logger is not None:
self.logger.finalize("success")
@ -371,10 +374,13 @@ class TrainerTrainLoopMixin(ABC):
# before epoch hook
if self.is_function_implemented('on_epoch_start'):
model = self.get_model()
model.on_epoch_start()
with self.profiler.profile('on_epoch_start'):
model.on_epoch_start()
# run epoch
for batch_idx, batch in enumerate(self.get_train_dataloader()):
for batch_idx, batch in self.profiler.profile_iterable(
enumerate(self.get_train_dataloader()), "get_train_batch"
):
# stop epoch if we limited the number of training batches
if batch_idx >= self.num_training_batches:
break
@ -432,7 +438,8 @@ class TrainerTrainLoopMixin(ABC):
# epoch end hook
if self.is_function_implemented('on_epoch_end'):
model = self.get_model()
model.on_epoch_end()
with self.profiler.profile('on_epoch_end'):
model.on_epoch_end()
def run_training_batch(self, batch, batch_idx):
# track grad norms
@ -450,7 +457,8 @@ class TrainerTrainLoopMixin(ABC):
# hook
if self.is_function_implemented('on_batch_start'):
model_ref = self.get_model()
response = model_ref.on_batch_start(batch)
with self.profiler.profile('on_batch_start'):
response = model_ref.on_batch_start(batch)
if response == -1:
return -1, grad_norm_dic, {}
@ -458,7 +466,8 @@ class TrainerTrainLoopMixin(ABC):
splits = [batch]
if self.truncated_bptt_steps is not None:
model_ref = self.get_model()
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
with self.profiler.profile('tbptt_split_batch'):
splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps)
self.hiddens = None
for split_idx, split_batch in enumerate(splits):
@ -478,8 +487,9 @@ class TrainerTrainLoopMixin(ABC):
# wrap the forward step in a closure so second order methods work
def optimizer_closure():
# forward pass
output = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)
with self.profiler.profile('model_forward'):
output = self.training_forward(
split_batch, batch_idx, opt_idx, self.hiddens)
closure_loss = output[0]
progress_bar_metrics = output[1]
@ -493,7 +503,8 @@ class TrainerTrainLoopMixin(ABC):
# backward pass
model_ref = self.get_model()
model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)
with self.profiler.profile('model_backward'):
model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)
# track metrics for callbacks
all_callback_metrics.append(callback_metrics)
@ -505,7 +516,8 @@ class TrainerTrainLoopMixin(ABC):
# insert after step hook
if self.is_function_implemented('on_after_backward'):
model_ref = self.get_model()
model_ref.on_after_backward()
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()
return closure_loss
@ -535,8 +547,9 @@ class TrainerTrainLoopMixin(ABC):
# calls .step(), .zero_grad()
# override function to modify this behavior
model = self.get_model()
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx, optimizer_closure)
with self.profiler.profile('optimizer_step'):
model.optimizer_step(self.current_epoch, batch_idx,
optimizer, opt_idx, optimizer_closure)
# calculate running loss for display
self.running_loss.append(self.batch_loss_value)
@ -546,7 +559,8 @@ class TrainerTrainLoopMixin(ABC):
# activate batch end hook
if self.is_function_implemented('on_batch_end'):
model = self.get_model()
model.on_batch_end()
with self.profiler.profile('on_batch_end'):
model.on_batch_end()
# update progress bar
self.main_progress_bar.update(1)
@ -606,7 +620,8 @@ class TrainerTrainLoopMixin(ABC):
# allow any mode to define training_end
if self.is_overriden('training_end'):
model_ref = self.get_model()
output = model_ref.training_end(output)
with self.profiler.profile('training_end'):
output = model_ref.training_end(output)
# format and reduce outputs accordingly
output = self.process_output(output, train=True)

50
tests/test_profiler.py Normal file
View File

@ -0,0 +1,50 @@
from pytorch_lightning.profiler import Profiler, AdvancedProfiler
import time
import numpy as np
def test_simple_profiler():
p = Profiler()
with p.profile("a"):
time.sleep(3)
with p.profile("a"):
time.sleep(1)
with p.profile("b"):
time.sleep(2)
with p.profile("c"):
time.sleep(1)
# different environments have different precision when it comes to time.sleep()
np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1)
np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1)
np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1)
def test_advanced_profiler():
def get_duration(profile):
return sum([x.totaltime for x in profile.getstats()])
p = AdvancedProfiler()
with p.profile("a"):
time.sleep(3)
with p.profile("a"):
time.sleep(1)
with p.profile("b"):
time.sleep(2)
with p.profile("c"):
time.sleep(1)
a_duration = get_duration(p.profiled_actions["a"])
np.testing.assert_almost_equal(a_duration, [4], decimal=1)
b_duration = get_duration(p.profiled_actions["b"])
np.testing.assert_almost_equal(b_duration, [2], decimal=1)
c_duration = get_duration(p.profiled_actions["c"])
np.testing.assert_almost_equal(c_duration, [1], decimal=1)