diff --git a/docs/source/common-cases.rst b/docs/source/common-cases.rst index 7b96a93d84..cc4ca362fc 100644 --- a/docs/source/common-cases.rst +++ b/docs/source/common-cases.rst @@ -13,9 +13,15 @@ gradient clipping modifying training via hooks ============================= - - .. toctree:: :maxdepth: 3 - pl_examples \ No newline at end of file + pl_examples + + +profiling a training run +======================== +.. toctree:: + :maxdepth: 1 + + profiler \ No newline at end of file diff --git a/docs/source/profiler.rst b/docs/source/profiler.rst new file mode 100644 index 0000000000..6443e7ddbc --- /dev/null +++ b/docs/source/profiler.rst @@ -0,0 +1,10 @@ +.. role:: hidden + :class: hidden-section + + +Profiling performance during training +=========== +.. automodule:: pytorch_lightning.profiler + :exclude-members: + _abc_impl, + summarize, diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py new file mode 100644 index 0000000000..a69e3ccf9c --- /dev/null +++ b/pytorch_lightning/profiler/__init__.py @@ -0,0 +1,112 @@ +""" +Profiling your training run can help you understand if there are any bottlenecks in your code. + +PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: + +- on_epoch_start +- on_epoch_end +- on_batch_start +- tbptt_split_batch +- model_forward +- model_backward +- on_after_backward +- optimizer_step +- on_batch_end +- training_end +- on_training_end + +If you only wish to profile the standard actions, you can set `profiler=True` when constructing +your `Trainer` object. + +.. code-block:: python + + trainer = Trainer(..., profiler=True) + +The profiler's results will be printed at the completion of a training `fit()`. + +.. code-block:: python + + Profiler Report + + Action | Mean duration (s) | Total time (s) + ----------------------------------------------------------------- + on_epoch_start | 5.993e-06 | 5.993e-06 + get_train_batch | 0.0087412 | 16.398 + on_batch_start | 5.0865e-06 | 0.0095372 + model_forward | 0.0017818 | 3.3408 + model_backward | 0.0018283 | 3.4282 + on_after_backward | 4.2862e-06 | 0.0080366 + optimizer_step | 0.0011072 | 2.0759 + on_batch_end | 4.5202e-06 | 0.0084753 + on_epoch_end | 3.919e-06 | 3.919e-06 + on_train_end | 5.449e-06 | 5.449e-06 + + +If you want more information on the functions called during each event, you can use the `AdvancedProfiler`. +This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code. + +.. _cProfiler: https://docs.python.org/3/library/profile.html#module-cProfile + +.. code-block:: python + + profiler = AdvancedProfiler() + trainer = Trainer(..., profiler=profiler) + +The profiler's results will be printed at the completion of a training `fit()`. This profiler +report can be quite long, so you can also specify an `output_filename` to save the report instead +of logging it to the output in your terminal. The output below shows the profiling for the action +`get_train_batch`. + +.. code-block:: python + + Profiler Report + + Profile stats for: get_train_batch + 4869394 function calls (4863767 primitive calls) in 18.893 seconds + Ordered by: cumulative time + List reduced from 76 to 10 due to restriction <10> + ncalls tottime percall cumtime percall filename:lineno(function) + 3752/1876 0.011 0.000 18.887 0.010 {built-in method builtins.next} + 1876 0.008 0.000 18.877 0.010 dataloader.py:344(__next__) + 1876 0.074 0.000 18.869 0.010 dataloader.py:383(_next_data) + 1875 0.012 0.000 18.721 0.010 fetch.py:42(fetch) + 1875 0.084 0.000 18.290 0.010 fetch.py:44() + 60000 1.759 0.000 18.206 0.000 mnist.py:80(__getitem__) + 60000 0.267 0.000 13.022 0.000 transforms.py:68(__call__) + 60000 0.182 0.000 7.020 0.000 transforms.py:93(__call__) + 60000 1.651 0.000 6.839 0.000 functional.py:42(to_tensor) + 60000 0.260 0.000 5.734 0.000 transforms.py:167(__call__) + +You can also reference this profiler in your LightningModule to profile specific actions of interest. +If you don't want to always have the profiler turned on, you can optionally pass a `PassThroughProfiler` +which will allow you to skip profiling without having to make any code changes. Each profiler has a +method `profile()` which returns a context handler. Simply pass in the name of your action that you want +to track and the profiler will record performance for code executed within this context. + +.. code-block:: python + + from pytorch_lightning.profiler import Profiler, PassThroughProfiler + + class MyModel(LightningModule): + def __init__(self, hparams, profiler=None): + self.hparams = hparams + self.profiler = profiler or PassThroughProfiler() + + def custom_processing_step(self, data): + with profiler.profile('my_custom_action'): + # custom processing step + return data + + profiler = Profiler() + model = MyModel(hparams, profiler) + trainer = Trainer(profiler=profiler, max_epochs=1) + +""" + +from .profiler import Profiler, AdvancedProfiler, PassThroughProfiler + +__all__ = [ + 'Profiler', + 'AdvancedProfiler', + 'PassThroughProfiler', +] diff --git a/pytorch_lightning/profiler/profiler.py b/pytorch_lightning/profiler/profiler.py new file mode 100644 index 0000000000..32f220897a --- /dev/null +++ b/pytorch_lightning/profiler/profiler.py @@ -0,0 +1,181 @@ +from contextlib import contextmanager +from collections import defaultdict +import time +import numpy as np +import cProfile +import pstats +import io +from abc import ABC, abstractmethod +import logging + +logger = logging.getLogger(__name__) + + +class BaseProfiler(ABC): + """ + If you wish to write a custom profiler, you should inhereit from this class. + """ + + @abstractmethod + def start(self, action_name): + """ + Defines how to start recording an action. + """ + pass + + @abstractmethod + def stop(self, action_name): + """ + Defines how to record the duration once an action is complete. + """ + pass + + @contextmanager + def profile(self, action_name): + """ + Yields a context manager to encapsulate the scope of a profiled action. + + Example:: + + with self.profile('load training data'): + # load training data code + + The profiler will start once you've entered the context and will automatically + stop once you exit the code block. + """ + try: + self.start(action_name) + yield action_name + finally: + self.stop(action_name) + + def profile_iterable(self, iterable, action_name): + iterator = iter(iterable) + while True: + try: + self.start(action_name) + value = next(iterator) + self.stop(action_name) + yield value + except StopIteration: + self.stop(action_name) + break + + def describe(self): + """ + Logs a profile report after the conclusion of the training run. + """ + pass + + +class PassThroughProfiler(BaseProfiler): + """ + This class should be used when you don't want the (small) overhead of profiling. + The Trainer uses this class by default. + """ + + def __init__(self): + pass + + def start(self, action_name): + pass + + def stop(self, action_name): + pass + + +class Profiler(BaseProfiler): + """ + This profiler simply records the duration of actions (in seconds) and reports + the mean duration of each action and the total time spent over the entire training run. + """ + + def __init__(self): + self.current_actions = {} + self.recorded_durations = defaultdict(list) + + def start(self, action_name): + if action_name in self.current_actions: + raise ValueError( + f"Attempted to start {action_name} which has already started." + ) + self.current_actions[action_name] = time.monotonic() + + def stop(self, action_name): + end_time = time.monotonic() + if action_name not in self.current_actions: + raise ValueError( + f"Attempting to stop recording an action ({action_name}) which was never started." + ) + start_time = self.current_actions.pop(action_name) + duration = end_time - start_time + self.recorded_durations[action_name].append(duration) + + def describe(self): + output_string = "\n\nProfiler Report\n" + + def log_row(action, mean, total): + return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}" + + output_string += log_row("Action", "Mean duration (s)", "Total time (s)") + output_string += f"\n{'-' * 65}" + for action, durations in self.recorded_durations.items(): + output_string += log_row( + action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}", + ) + output_string += "\n" + logger.info(output_string) + + +class AdvancedProfiler(BaseProfiler): + """ + This profiler uses Python's cProfiler to record more detailed information about + time spent in each function call recorded during a given action. The output is quite + verbose and you should only use this if you want very detailed reports. + """ + + def __init__(self, output_filename=None, line_count_restriction=1.0): + """ + :param output_filename (str): optionally save profile results to file instead of printing + to std out when training is finished. + :param line_count_restriction (int|float): this can be used to limit the number of functions + reported for each action. either an integer (to select a count of lines), + or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) + """ + self.profiled_actions = {} + self.output_filename = output_filename + self.line_count_restriction = line_count_restriction + + def start(self, action_name): + if action_name not in self.profiled_actions: + self.profiled_actions[action_name] = cProfile.Profile() + self.profiled_actions[action_name].enable() + + def stop(self, action_name): + pr = self.profiled_actions.get(action_name) + if pr is None: + raise ValueError( + f"Attempting to stop recording an action ({action_name}) which was never started." + ) + pr.disable() + + def describe(self): + self.recorded_stats = {} + for action_name, pr in self.profiled_actions.items(): + s = io.StringIO() + sortby = pstats.SortKey.CUMULATIVE + ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby) + ps.print_stats(self.line_count_restriction) + self.recorded_stats[action_name] = s.getvalue() + if self.output_filename is not None: + # save to file + with open(self.output_filename, "w") as f: + for action, stats in self.recorded_stats.items(): + f.write(f"Profile stats for: {action}") + f.write(stats) + else: + # log to standard out + output_string = "\nProfiler Report\n" + for action, stats in self.recorded_stats.items(): + output_string += f"\nProfile stats for: {action}\n{stats}" + logger.info(output_string) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 82eb052a95..9f9f89b7ac 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -212,7 +212,7 @@ class TrainerEvaluationLoopMixin(ABC): # bookkeeping outputs = [] - # run training + # run validation for dataloader_idx, dataloader in enumerate(dataloaders): dl_outputs = [] for batch_idx, batch in enumerate(dataloader): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6a758f7cca..8871b2eae8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.profiler import Profiler, PassThroughProfiler + try: from apex import amp @@ -87,6 +89,7 @@ class Trainer(TrainerIOMixin, num_sanity_val_steps=5, truncated_bptt_steps=None, resume_from_checkpoint=None, + profiler=None ): r""" @@ -460,6 +463,25 @@ class Trainer(TrainerIOMixin, # resume from a specific checkpoint trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt') + profiler (BaseProfiler): To profile individual steps during training and assist in + identifying bottlenecks. + Example:: + + from pytorch_lightning.profiler import Profiler, AdvancedProfiler + + # default used by the Trainer + trainer = Trainer(profiler=None) + + # to profile standard training events + trainer = Trainer(profiler=True) + + # equivalent to profiler=True + profiler = Profiler() + trainer = Trainer(profiler=profiler) + + # advanced profiler for function-level stats + profiler = AdvancedProfiler() + trainer = Trainer(profiler=profiler) .. warning:: Following arguments become deprecated and they will be removed in v0.8.0: @@ -564,6 +586,11 @@ class Trainer(TrainerIOMixin, # configure logger self.configure_logger(logger) + # configure profiler + if profiler is True: + profiler = Profiler() + self.profiler = profiler or PassThroughProfiler() + # configure early stop callback # creates a default one if none passed in self.configure_early_stopping(early_stop_callback) @@ -870,6 +897,9 @@ class Trainer(TrainerIOMixin, # CORE TRAINING LOOP self.train() + # summarize profile results + self.profiler.describe() + def test(self, model=None): r""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a1a5676494..3e9a906016 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -211,6 +211,7 @@ class TrainerTrainLoopMixin(ABC): self.training_tqdm_dict = None self.get_train_dataloader = None self.reduce_lr_on_plateau_scheduler = None + self.profiler = None @property def max_nb_epochs(self): @@ -357,12 +358,14 @@ class TrainerTrainLoopMixin(ABC): stop = should_stop and met_min_epochs if stop: self.main_progress_bar.close() - model.on_train_end() + with self.profiler.profile('on_train_end'): + model.on_train_end() return self.main_progress_bar.close() - model.on_train_end() + with self.profiler.profile('on_train_end'): + model.on_train_end() if self.logger is not None: self.logger.finalize("success") @@ -371,10 +374,13 @@ class TrainerTrainLoopMixin(ABC): # before epoch hook if self.is_function_implemented('on_epoch_start'): model = self.get_model() - model.on_epoch_start() + with self.profiler.profile('on_epoch_start'): + model.on_epoch_start() # run epoch - for batch_idx, batch in enumerate(self.get_train_dataloader()): + for batch_idx, batch in self.profiler.profile_iterable( + enumerate(self.get_train_dataloader()), "get_train_batch" + ): # stop epoch if we limited the number of training batches if batch_idx >= self.num_training_batches: break @@ -432,7 +438,8 @@ class TrainerTrainLoopMixin(ABC): # epoch end hook if self.is_function_implemented('on_epoch_end'): model = self.get_model() - model.on_epoch_end() + with self.profiler.profile('on_epoch_end'): + model.on_epoch_end() def run_training_batch(self, batch, batch_idx): # track grad norms @@ -450,7 +457,8 @@ class TrainerTrainLoopMixin(ABC): # hook if self.is_function_implemented('on_batch_start'): model_ref = self.get_model() - response = model_ref.on_batch_start(batch) + with self.profiler.profile('on_batch_start'): + response = model_ref.on_batch_start(batch) if response == -1: return -1, grad_norm_dic, {} @@ -458,7 +466,8 @@ class TrainerTrainLoopMixin(ABC): splits = [batch] if self.truncated_bptt_steps is not None: model_ref = self.get_model() - splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) + with self.profiler.profile('tbptt_split_batch'): + splits = model_ref.tbptt_split_batch(batch, self.truncated_bptt_steps) self.hiddens = None for split_idx, split_batch in enumerate(splits): @@ -478,8 +487,9 @@ class TrainerTrainLoopMixin(ABC): # wrap the forward step in a closure so second order methods work def optimizer_closure(): # forward pass - output = self.training_forward( - split_batch, batch_idx, opt_idx, self.hiddens) + with self.profiler.profile('model_forward'): + output = self.training_forward( + split_batch, batch_idx, opt_idx, self.hiddens) closure_loss = output[0] progress_bar_metrics = output[1] @@ -493,7 +503,8 @@ class TrainerTrainLoopMixin(ABC): # backward pass model_ref = self.get_model() - model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx) + with self.profiler.profile('model_backward'): + model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx) # track metrics for callbacks all_callback_metrics.append(callback_metrics) @@ -505,7 +516,8 @@ class TrainerTrainLoopMixin(ABC): # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() - model_ref.on_after_backward() + with self.profiler.profile('on_after_backward'): + model_ref.on_after_backward() return closure_loss @@ -535,8 +547,9 @@ class TrainerTrainLoopMixin(ABC): # calls .step(), .zero_grad() # override function to modify this behavior model = self.get_model() - model.optimizer_step(self.current_epoch, batch_idx, - optimizer, opt_idx, optimizer_closure) + with self.profiler.profile('optimizer_step'): + model.optimizer_step(self.current_epoch, batch_idx, + optimizer, opt_idx, optimizer_closure) # calculate running loss for display self.running_loss.append(self.batch_loss_value) @@ -546,7 +559,8 @@ class TrainerTrainLoopMixin(ABC): # activate batch end hook if self.is_function_implemented('on_batch_end'): model = self.get_model() - model.on_batch_end() + with self.profiler.profile('on_batch_end'): + model.on_batch_end() # update progress bar self.main_progress_bar.update(1) @@ -606,7 +620,8 @@ class TrainerTrainLoopMixin(ABC): # allow any mode to define training_end if self.is_overriden('training_end'): model_ref = self.get_model() - output = model_ref.training_end(output) + with self.profiler.profile('training_end'): + output = model_ref.training_end(output) # format and reduce outputs accordingly output = self.process_output(output, train=True) diff --git a/tests/test_profiler.py b/tests/test_profiler.py new file mode 100644 index 0000000000..d6e085a55e --- /dev/null +++ b/tests/test_profiler.py @@ -0,0 +1,50 @@ +from pytorch_lightning.profiler import Profiler, AdvancedProfiler +import time +import numpy as np + + +def test_simple_profiler(): + p = Profiler() + + with p.profile("a"): + time.sleep(3) + + with p.profile("a"): + time.sleep(1) + + with p.profile("b"): + time.sleep(2) + + with p.profile("c"): + time.sleep(1) + + # different environments have different precision when it comes to time.sleep() + np.testing.assert_almost_equal(p.recorded_durations["a"], [3, 1], decimal=1) + np.testing.assert_almost_equal(p.recorded_durations["b"], [2], decimal=1) + np.testing.assert_almost_equal(p.recorded_durations["c"], [1], decimal=1) + + +def test_advanced_profiler(): + def get_duration(profile): + return sum([x.totaltime for x in profile.getstats()]) + + p = AdvancedProfiler() + + with p.profile("a"): + time.sleep(3) + + with p.profile("a"): + time.sleep(1) + + with p.profile("b"): + time.sleep(2) + + with p.profile("c"): + time.sleep(1) + + a_duration = get_duration(p.profiled_actions["a"]) + np.testing.assert_almost_equal(a_duration, [4], decimal=1) + b_duration = get_duration(p.profiled_actions["b"]) + np.testing.assert_almost_equal(b_duration, [2], decimal=1) + c_duration = get_duration(p.profiled_actions["c"]) + np.testing.assert_almost_equal(c_duration, [1], decimal=1)