XLA Profiler integration (#8014)

Kaushik B 2021-06-29 00:58:05 +05:30 committed by GitHub
parent c521624a92
commit 2f3c65e57b
11 changed files with 194 additions and 2 deletions

View File

@@ -91,7 +91,7 @@ jobs:
docker:
- image: circleci/python:3.7
environment:
- XLA_VER: 1.7
- XLA_VER: 1.8
- MAX_CHECKS: 240
- CHECK_SPEEP: 5
steps:

View File

@@ -118,6 +118,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Add torchelastic check when sanitizing GPUs ([#8095](https://github.com/PyTorchLightning/pytorch-lightning/pull/8095))
- Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))
### Changed

View File

@@ -22,6 +22,7 @@ local tputests = base.BaseTest {
|||
cd pytorch-lightning
coverage run --source=pytorch_lightning -m pytest -v --capture=no \
tests/profiler/test_xla_profiler.py \
pytorch_lightning/utilities/xla_device.py \
tests/accelerators/test_tpu_backend.py \
tests/models/test_tpu.py

View File

@@ -198,6 +198,7 @@ from pytorch_lightning.profiler.advanced import AdvancedProfiler
from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler
from pytorch_lightning.profiler.pytorch import PyTorchProfiler
from pytorch_lightning.profiler.simple import SimpleProfiler
from pytorch_lightning.profiler.xla import XLAProfiler
__all__ = [
'AbstractProfiler',
@@ -206,4 +207,5 @@ __all__ = [
'PassThroughProfiler',
'PyTorchProfiler',
'SimpleProfiler',
'XLAProfiler',
]

View File

@@ -9,6 +9,7 @@ from pytorch_lightning.profiler.advanced import AdvancedProfiler  # noqa E402
from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler # noqa E402
from pytorch_lightning.profiler.pytorch import PyTorchProfiler # noqa E402
from pytorch_lightning.profiler.simple import SimpleProfiler # noqa E402
from pytorch_lightning.profiler.xla import XLAProfiler # noqa E402
__all__ = [
'AbstractProfiler',
@@ -17,4 +18,5 @@ __all__ = [
'PassThroughProfiler',
'PyTorchProfiler',
'SimpleProfiler',
'XLAProfiler',
]
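
Since ``XLAProfiler`` is now exported from ``pytorch_lightning.profiler``, it can be handed to the ``Trainer`` directly. A minimal usage sketch (the ``tpu_cores`` and ``max_epochs`` values are illustrative, not part of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler

# instantiate explicitly to control the profile-server port (9012 is the default)
profiler = XLAProfiler(port=9012)
trainer = Trainer(profiler=profiler, tpu_cores=8, max_epochs=1)
# trainer.fit(model) would then emit XLA traces for the hooked actions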

View File

@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
XLA Profiler will help you debug and optimize training workload performance
for your models using Cloud TPU performance tools.
Manual capture via TensorBoard

The following instructions describe how to capture a trace from a running program:

0. This [guide](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm#tpu-vm) will
help you set up the Cloud TPU with the required installations

1. Start a TensorBoard server

>> tensorboard --logdir ./tensorboard --port 9001

You can view the TensorBoard output at http://localhost:9001 on your local machine, and then open the
``PROFILE`` plugin from the top-right dropdown or open http://localhost:9001/#profile

2. Once the code you'd like to profile is running, click the ``CAPTURE PROFILE`` button. Enter
``localhost:9012`` (the default port for the XLA Profiler) as the Profile Service URL, enter the
profiling duration in milliseconds, and click ``CAPTURE``

3. Make sure the code is still running while you capture the trace. A profiling duration longer than the
step time yields better performance insights

4. Once the capture is finished, the page will refresh and you can browse through the insights using the
``Tools`` dropdown at the top left
"""
import logging
from typing import Dict
from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.utilities import _TPU_AVAILABLE
if _TPU_AVAILABLE:
    import torch_xla.debug.profiler as xp

log = logging.getLogger(__name__)


class XLAProfiler(BaseProfiler):

    STEP_FUNCTIONS = {
        "training_step_and_backward",
        "validation_step",
        "test_step",
        "predict_step",
    }
    RECORD_FUNCTIONS = {
        "training_step_and_backward",
        "training_step",
        "backward",
        "validation_step",
        "test_step",
        "predict_step",
    }

    def __init__(self, port: int = 9012) -> None:
        """
        This Profiler will help you debug and optimize training workload performance
        for your models using Cloud TPU performance tools.
        """
        super().__init__(dirpath=None, filename=None, output_filename=None)
        self.port = port
        self._recording_map: Dict = {}
        self._step_recoding_map: Dict = {}
        self._start_trace: bool = False

    def start(self, action_name: str) -> None:
        if action_name in self.RECORD_FUNCTIONS:
            if not self._start_trace:
                self.server = xp.start_server(self.port)
                self._start_trace = True

            if action_name in self.STEP_FUNCTIONS:
                step = self._get_step_num(action_name)
                recording = xp.StepTrace(action_name, step_num=step)
            else:
                recording = xp.Trace(action_name)
            recording.__enter__()
            self._recording_map[action_name] = recording

    def stop(self, action_name: str) -> None:
        if action_name in self._recording_map:
            self._recording_map[action_name].__exit__(None, None, None)
            del self._recording_map[action_name]

    def _get_step_num(self, action_name: str) -> int:
        if action_name not in self._step_recoding_map:
            self._step_recoding_map[action_name] = 1
        else:
            self._step_recoding_map[action_name] += 1
        return self._step_recoding_map[action_name]

    def summary(self) -> str:
        return ""
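
For context, ``start``/``stop`` simply enter and exit ``torch_xla`` trace contexts around the hooked actions. A rough manual equivalent, assuming a TPU environment with ``torch_xla`` installed, looks like this:

import torch_xla.debug.profiler as xp

# start the profile service that the TensorBoard PROFILE plugin connects to
server = xp.start_server(9012)

# non-step actions are wrapped in a plain named trace ...
with xp.Trace("backward"):
    pass  # the hooked action body runs here

# ... while step-level actions use StepTrace with an increasing step number
with xp.StepTrace("training_step_and_backward", step_num=1):
    pass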

View File

@@ -39,6 +39,7 @@ from pytorch_lightning.profiler import (
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
XLAProfiler,
)
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
@@ -1183,6 +1184,7 @@ class Trainer(
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler,
"xla": XLAProfiler,
}
profiler = profiler.lower()
if profiler not in PROFILERS:
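
With the ``"xla"`` key registered above, the string form alone enables the profiler; the new TPU test uses exactly this (shown here as a small sketch):

from pytorch_lightning import Trainer

# the lowercase string is looked up in PROFILERS and instantiated with defaults (port 9012)
trainer = Trainer(profiler="xla", tpu_cores=8)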

View File

View File

@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from glob import glob
from multiprocessing import Event, Process
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler
from pytorch_lightning.utilities import _TPU_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if _TPU_AVAILABLE:
    import torch_xla.debug.profiler as xp
    import torch_xla.utils.utils as xu


@RunIf(tpu=True)
def test_xla_profiler_instance(tmpdir):
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        profiler="xla",
        tpu_cores=8,
    )
    assert isinstance(trainer.profiler, XLAProfiler)
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
@pytest.mark.skipif(True, reason="XLA Profiler doesn't support Prog. capture yet")
def test_xla_profiler_prog_capture(tmpdir):
    port = xu.get_free_tcp_ports()[0]
    training_started = Event()

    def train_worker():
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=4,
            profiler="xla",
            tpu_cores=8,
        )
        trainer.fit(model)

    p = Process(target=train_worker, daemon=True)
    p.start()
    training_started.wait(120)

    logdir = str(tmpdir)
    xp.trace(f'localhost:{port}', logdir, duration_ms=2000, num_tracing_attempts=5, delay_ms=1000)

    p.terminate()

    # the capture writes ``*.xplane.pb`` files under ``plugins/profile/<run>/``; a plain
    # ``os.path.isfile`` check does not expand wildcards, so glob for the generated traces instead
    assert len(glob(os.path.join(logdir, 'plugins', 'profile', '*', '*.xplane.pb'))) > 0

View File

@@ -69,7 +69,7 @@ for i in "${!files_arr[@]}"; do
done
if nvcc --version; then
nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx
nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx
fi
# needs to run outside of `pytest`