diff --git a/docs/source/advanced/profiler.rst b/docs/source/advanced/profiler.rst
index eccba79978..53171d3695 100644
--- a/docs/source/advanced/profiler.rst
+++ b/docs/source/advanced/profiler.rst
@@ -208,6 +208,50 @@ To visualize the profiled operation, you can either:
 
     python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'
 
+XLA Profiler
+============
+
+:class:`~pytorch_lightning.profiler.xla.XLAProfiler` will help you debug and optimize training
+workload performance for your models using Cloud TPU performance tools.
+
+.. code-block:: python
+
+    # by passing the `XLAProfiler` alias
+    trainer = Trainer(..., profiler="xla")
+
+    # or by passing an instance
+    from pytorch_lightning.profiler import XLAProfiler
+
+    profiler = XLAProfiler(port=9012)
+    trainer = Trainer(..., profiler=profiler)
+
+
+Manual Capture via TensorBoard
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The following instructions are for capturing traces from a running program:
+
+0. This `guide <https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm#tpu-vm>`_ will
+help you with the Cloud TPU setup and the required installations.
+
+1. Start a `TensorBoard <https://www.tensorflow.org/tensorboard>`_ server. You can view the TensorBoard output at
+``http://localhost:9001`` on your local machine, and then open the ``PROFILE`` plugin from the top right dropdown or open ``http://localhost:9001/#profile``.
+
+.. code-block:: bash
+
+    tensorboard --logdir ./tensorboard --port 9001
+
+2. Once the code you'd like to profile is running, click on the ``CAPTURE PROFILE`` button. Enter
+``localhost:9012`` (the default port of the XLA Profiler server) as the Profile Service URL. Then, enter
+the number of milliseconds for the profiling duration, and click ``CAPTURE``.
+
+3. Make sure the code is still running while you capture the traces. You will also get better
+performance insights if the profiling duration is longer than the step time.
+
+4. Once the capture is finished, the page will refresh and you can browse through the insights using the
+``Tools`` dropdown at the top left.
+
+
 ----------------
 
 ****************
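The capture in step 2 above can also be scripted instead of clicked through in the TensorBoard UI. Below is a minimal sketch using ``torch_xla``'s profiler client; this helper lives in ``torch_xla``, not in this diff, and it assumes the ``XLAProfiler`` server is already listening on its default port 9012:

.. code-block:: python

    # A sketch, not part of this diff: trigger a capture programmatically with
    # torch_xla's profiler client (requires torch_xla >= 1.8 on a TPU VM).
    import torch_xla.debug.profiler as xp

    # Record ~2 seconds of activity into the directory TensorBoard is watching.
    # Use a duration longer than one training step for useful insights.
    xp.trace("localhost:9012", logdir="./tensorboard", duration_ms=2000)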
diff --git a/pytorch_lightning/profiler/xla.py b/pytorch_lightning/profiler/xla.py
index be158f7be4..120b858206 100644
--- a/pytorch_lightning/profiler/xla.py
+++ b/pytorch_lightning/profiler/xla.py
@@ -11,33 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU
-performance tools.
-
-Manual capture via TensorBoard
-
-The following instructions are for capturing trace from a running program
-
-0. This [guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm#tpu-vm) will
-help you with the Cloud TPU setup with the required installations
-
-1. Start a TensorBoard Server
-
->> tensorboard --logdir ./tensorboard --port 9001
-
-You could view the TensorBoard output at http://localhost:9001 on your local machine, and then open the
-``PROFILE`` plugin from the top right dropdown or open http://localhost:9001/#profile
-
-2. Once the code you'd like to profile is running, click on ``CAPTURE PROFILE`` button. You could enter
-``localhost:9012`` (default port for XLA Profiler) as the Profile Service URL. Then, you could enter
-the number of milliseconds for the profiling duration, and click ``CAPTURE``
-
-3. Make sure the code is running, while you are trying to capture the traces. Also, it would lead to better
-performance insights if the profiling duration is longer than the step time
-
-4. Once the capture is finished, the page will refresh and you could browse through the insights using the
-``Tools`` dropdown at the top left
-"""
 import logging
 from typing import Dict
@@ -63,8 +36,13 @@ class XLAProfiler(BaseProfiler):
     }
 
     def __init__(self, port: int = 9012) -> None:
-        """This Profiler will help you debug and optimize training workload performance for your models using Cloud
-        TPU performance tools."""
+        """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud
+        TPU performance tools.
+
+        Args:
+            port: the port to start the profiler server on. An exception is
+                raised if the provided port is invalid or busy.
+        """
         if not _TPU_AVAILABLE:
             raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
         if not _TORCH_GREATER_EQUAL_1_8:
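For reference, here is a self-contained sketch showing how the profiler documented above plugs into a training run. ``ToyModel`` and its dataloader are invented stand-ins for illustration, and a TPU VM with ``torch_xla`` installed is assumed; on any other machine ``XLAProfiler`` raises a ``MisconfigurationException``, as the diff shows.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.profiler import XLAProfiler


    class ToyModel(LightningModule):
        """Minimal stand-in model so the sketch is self-contained."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return F.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        x, y = torch.randn(64, 32), torch.randn(64, 2)
        train_loader = DataLoader(TensorDataset(x, y), batch_size=8)

        # The profiler server listens on port 9012 by default; point the
        # CAPTURE PROFILE dialog at localhost:9012 while `fit` is running.
        profiler = XLAProfiler(port=9012)
        trainer = Trainer(tpu_cores=8, max_epochs=1, profiler=profiler)
        trainer.fit(ToyModel(), train_dataloaders=train_loader)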