[feat] Add PyTorch Profiler. (#5560)

* add profiler

* add profiler

* update

* resolve flake8

* update doc

* update changelog

* clean doc

* delete prof file

* merge pr codebase

* update

* update doc

* update doc

* update doc

* update on comments

* update docstring

* update docstring

* try

* update test

* Update pytorch_lightning/profiler/__init__.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/profiler/__init__.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update on comments

* remove old code

* add support for ddp

* resolve flake8

* Update pytorch_lightning/profiler/__init__.py

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

* resolve tests

* resolve flake8

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
chaton 2021-01-26 11:48:54 +00:00 committed by GitHub
parent f782230412
commit 5f3372871a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 500 additions and 13 deletions

1
.gitignore vendored
View File

@ -141,3 +141,4 @@ pytorch\ lightning
test-reports/
wandb
.forked/
*.prof

View File

@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))
- Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560))
### Changed

View File

@ -16,7 +16,7 @@ import os
import shutil
import subprocess
from collections import OrderedDict
from typing import Tuple, Dict, Union, List, Any
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
@ -182,7 +182,8 @@ class ModelSummary(object):
self._model = model
self._mode = mode
self._layer_summary = self.summarize()
self._precision_megabytes = (self._model.precision / 8.0) * 1e-6 # 1 byte -> 8 bits
# 1 byte -> 8 bits
self._precision_megabytes = (self._model.precision / 8.0) * 1e-6
@property
def named_modules(self) -> List[Tuple[str, nn.Module]]:

View File

@ -50,7 +50,7 @@ The profiler's results will be printed at the completion of a training `fit()`.
Advanced Profiling
--------------------
------------------
If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
@ -114,13 +114,98 @@ to track and the profiler will record performance for code executed within this
model = MyModel(profiler)
trainer = Trainer(profiler=profiler, max_epochs=1)
PyTorch Profiling
-----------------
Autograd includes a profiler that lets you inspect the cost of different operators
inside your model - both on the CPU and GPU.
Find the Pytorch Profiler doc at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html)
.. code-block:: python
trainer = Trainer(..., profiler="pytorch")
or
profiler = PyTorchProfiler(...)
trainer = Trainer(..., profiler=profiler)
This profiler works with PyTorch ``DistributedDataParallel``.
If ``output_filename`` is provided, each rank will save their profiled operation to their own file.
The profiler's results will be printed on the completion of a training `fit()`. This profiler
report can be quite long, so you can also specify an `output_filename` to save the report instead
of logging it to the output in your terminal.
This profiler will record only for `training_step_and_backward`, `evaluation_step` and `test_step` functions by default.
The output below shows the profiling for the action `training_step_and_backward`.
The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions.
.. note:: When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`. # noqa E501
.. code-block:: python
Profiler Report
Profile stats for: training_step_and_backward
--------------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg
--------------------- --------------- --------------- --------------- --------------- ---------------
t 62.10% 1.044ms 62.77% 1.055ms 1.055ms
addmm 32.32% 543.135us 32.69% 549.362us 549.362us
mse_loss 1.35% 22.657us 3.58% 60.105us 60.105us
mean 0.22% 3.694us 2.05% 34.523us 34.523us
div_ 0.64% 10.756us 1.90% 32.001us 16.000us
ones_like 0.21% 3.461us 0.81% 13.669us 13.669us
sum_out 0.45% 7.638us 0.74% 12.432us 12.432us
transpose 0.23% 3.786us 0.68% 11.393us 11.393us
as_strided 0.60% 10.060us 0.60% 10.060us 3.353us
to 0.18% 3.059us 0.44% 7.464us 7.464us
empty_like 0.14% 2.387us 0.41% 6.859us 6.859us
empty_strided 0.38% 6.351us 0.38% 6.351us 3.175us
fill_ 0.28% 4.782us 0.33% 5.566us 2.783us
expand 0.20% 3.336us 0.28% 4.743us 4.743us
empty 0.27% 4.456us 0.27% 4.456us 2.228us
copy_ 0.15% 2.526us 0.15% 2.526us 2.526us
broadcast_tensors 0.15% 2.492us 0.15% 2.492us 2.492us
size 0.06% 0.967us 0.06% 0.967us 0.484us
is_complex 0.06% 0.961us 0.06% 0.961us 0.481us
stride 0.03% 0.517us 0.03% 0.517us 0.517us
--------------------- --------------- --------------- --------------- --------------- ---------------
Self CPU time total: 1.681ms
When running with `PyTorchProfiler(emit_nvtx=True)`. You should run as following::
nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
To visualize the profiled operation, you can either:
* Use::
nvvp trace_name.prof
* Use::
python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'
"""
from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.profilers import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
)
__all__ = [
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
"PyTorchProfiler",
]

View File

@ -15,6 +15,7 @@
"""Profiler to check if there are any bottlenecks in your code."""
import cProfile
import inspect
import io
import os
import pstats
@ -22,12 +23,16 @@ import time
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import List, Optional, Union
import numpy as np
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class BaseProfiler(ABC):
@ -95,6 +100,9 @@ class BaseProfiler(ABC):
def summary(self) -> str:
"""Create profiler summary in text format."""
def on_train_start(self, local_rank: Optional[int] = None):
self.local_rank = local_rank
class PassThroughProfiler(BaseProfiler):
"""
@ -282,3 +290,263 @@ class AdvancedProfiler(BaseProfiler):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
class PyTorchProfiler(BaseProfiler):
PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step")
AVAILABLE_SORT_KEYS = (
"cpu_time", "cuda_time", "cpu_time_total",
"cuda_time_total", "cpu_memory_usage", "cuda_memory_usage",
"self_cpu_memory_usage", "self_cuda_memory_usage", "count"
)
def __init__(
self,
output_filename: Optional[str] = None,
enabled: bool = True,
use_cuda: bool = False,
record_shapes: bool = False,
profile_memory: bool = False,
group_by_input_shapes: bool = False,
with_stack: bool = False,
use_kineto: bool = False,
use_cpu: bool = False,
emit_nvtx: bool = False,
export_to_chrome: bool = False,
path_to_export_trace: str = None,
row_limit: int = 20,
sort_by_key: Optional[str] = None,
profiled_functions: Optional[List] = None,
local_rank: Optional[int] = None,
):
"""
This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
different operators inside your model - both on the CPU and GPU
Args:
output_filename: optionally save profile results to file instead of printing
to std out when training is finished. When using ``ddp``,
each rank will stream the profiled operation to their own file
with the extension ``_{rank}.txt``
enabled: Setting this to False makes this context manager a no-op.
use_cuda: Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.
record_shapes: If shapes recording is set, information about input dimensions will be collected.
profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0)
group_by_input_shapes: Include operator input shapes and group calls by shape.
with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0)
use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0)
use_cpu: use_kineto=True and can be used to lower the overhead
for GPU-only profiling (Introduced in PyTorch 1.8.0)
emit_nvtx: Context manager that makes every autograd operation emit an NVTX range
Run::
nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
To visualize, you can either use::
nvvp trace_name.prof
torch.autograd.profiler.load_nvprof(path)
export_to_chrome: Wether to export the sequence of profiled operators for Chrome.
It will generate a ``.json`` file which can be read by Chrome.
path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``.
By default, it will be save where the file being is being run.
row_limit: Limit the number of rows in a table, `0` is a special value that
removes the limit completely.
sort_by_key: Keys to sort out profiled table
profiled_functions: list of profiled functions which will create a context manager on.
Any other will be pass through.
local_rank: When running in distributed setting, local_rank is used for each process
to write to their own file if `output_fname` is provided.
"""
self.profiled_actions = {}
self.enabled = enabled
self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
self.use_cuda = use_cuda
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
self.with_stack = with_stack
self.group_by_input_shapes = group_by_input_shapes and record_shapes
self.use_kineto = use_kineto
self.use_cpu = use_cpu
self.row_limit = row_limit
self.emit_nvtx = emit_nvtx
self.export_to_chrome = export_to_chrome
self.path_to_export_trace = path_to_export_trace
if export_to_chrome and path_to_export_trace is None:
rank_zero_warn(
"The exported trace would be save locally as `path_to_export_trace` is empty."
" Note: Each functions will generate its own traced file.")
if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
raise MisconfigurationException(
f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. ")
self.profiled_actions = {}
self.context_names = {}
self.running_stack = []
self.profiler = None
self.output_fname = output_filename
self.output_file = None
if local_rank is not None:
self.on_train_start(local_rank=local_rank)
self.on_train_start = super().on_train_start
def on_train_start(self, local_rank: Optional[str] = None):
self.local_rank = local_rank
# when logging to `log.info`, only perform profiling on rank 0
if local_rank != 0 and self.output_fname is None:
self.wrap_functions_into_rank_zero_only()
if self.output_fname:
if local_rank is not None:
if '.txt' not in self.output_fname:
raise MisconfigurationException("Log file should be .txt file.")
self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt")
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")
streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
def wrap_functions_into_rank_zero_only(self):
self.start = rank_zero_only(self.start)
self.stop = rank_zero_only(self.stop)
self.summary = rank_zero_only(self.summary)
self.describe = rank_zero_only(self.describe)
def start(self, action_name: str) -> None:
if action_name not in self.profiled_functions:
return
if len(self.running_stack) > 0:
self._stop(self.running_stack[-1])
self.running_stack.append(action_name)
self.context_names[action_name] = "/".join(self.running_stack)
self._start(action_name)
def _start(self, action_name: str) -> None:
if self.emit_nvtx:
self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
else:
self._create_profiler(action_name, torch.autograd.profiler.profile)
def _create_profiler(self, action_name, profiler, enter=True):
init_args = inspect.signature(profiler.__init__).parameters
profiler_args = {
k: v for k, v in vars(self).items() if k in init_args
}
pr = profiler(**profiler_args)
if enter:
pr = pr.__enter__()
self.profiler = pr
def _stop(self, action_name: str) -> None:
if self.profiler is None:
return
self.profiler.__exit__(
exc_type=None,
exc_val=None,
exc_tb=None
)
function_events = self.profiler.function_events
self.profiler = None
for name in self.running_stack:
if name not in self.profiled_actions:
self.profiled_actions[name] = function_events
else:
self.profiled_actions[name] += function_events
def stop(self, action_name: str) -> None:
if action_name not in self.profiled_functions:
return
if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
self._stop(action_name)
self.running_stack.pop()
# restore running profiler
if len(self.running_stack) > 0:
self._start(self.running_stack[-1])
def summary(self) -> str:
recorded_stats = {}
output_string = ''
local_rank = '0' if self.local_rank is None else self.local_rank
if not self.enabled:
return output_string
for action_name, function_events in self.profiled_actions.items():
# next line is a workaround for a pytorch issue (fixed on master, still present
# on 1.7). Without it the code fails with `AssertionError: There is already a CPU
# parent event for detach`
function_events.populate_cpu_children = lambda: None
if self.export_to_chrome:
filename = f"{action_name}_{local_rank}_trace.json"
path_to_trace = filename if self.path_to_export_trace is None \
else os.path.join(self.path_to_export_trace, filename)
function_events.export_chrome_trace(path_to_trace)
if self.emit_nvtx:
return output_string
else:
table = function_events.key_averages(
group_by_input_shapes=self.group_by_input_shapes).table(
sort_by=self.sort_by_key,
row_limit=self.row_limit)
recorded_stats[action_name] = table
# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += (
f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}"
)
return output_string
def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()
def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

View File

@ -14,13 +14,20 @@
from typing import Union
from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler import (
AdvancedProfiler,
BaseProfiler,
PassThroughProfiler,
PyTorchProfiler,
SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
PROFILERS = {
"simple": SimpleProfiler,
"advanced": AdvancedProfiler,
"pytorch": PyTorchProfiler
}
@ -51,3 +58,7 @@ class ProfilerConnector:
raise ValueError("When passing string value for the `profiler` parameter of"
" `Trainer`, it can only be 'simple' or 'advanced'")
self.trainer.profiler = profiler or PassThroughProfiler()
def on_train_start(self, trainer):
local_rank = trainer.local_rank if trainer.world_size > 1 else None
self.trainer.profiler.on_train_start(local_rank)

View File

@ -164,10 +164,12 @@ class EvaluationLoop(object):
# run actual test step
if self.testing:
model_ref._current_fx_name = "test_step"
output = self.trainer.accelerator_backend.test_step(args)
with self.trainer.profiler.profile("test_step"):
output = self.trainer.accelerator_backend.test_step(args)
else:
model_ref._current_fx_name = "validation_step"
output = self.trainer.accelerator_backend.validation_step(args)
with self.trainer.profiler.profile("validation_step"):
output = self.trainer.accelerator_backend.validation_step(args)
# capture any logged information
self.trainer.logger_connector.cache_logged_metrics()

View File

@ -111,6 +111,9 @@ class TrainLoop:
# hook
self.trainer.call_hook("on_train_start")
# provide rank to profiler
self.trainer.profile_connector.on_train_start(self.trainer)
def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
# bind logger and other properties
self.trainer.model_connector.copy_trainer_model_properties(model)
@ -339,7 +342,8 @@ class TrainLoop:
# manually capture logged metrics
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
training_step_output = self.trainer.accelerator_backend.training_step(args)
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator_backend.training_step(args)
self.trainer.logger_connector.cache_logged_metrics()
self._check_training_step_output(training_step_output)

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from abc import ABC
from collections import OrderedDict

View File

@ -40,8 +40,10 @@ class PreCalculatedModel(BoringModel):
def __init__(self, precision: int = 32):
super().__init__()
self.layer = nn.Linear(32, 1000, bias=False) # 32K params
self.layer1 = nn.Linear(1000, 218, bias=False) # 218K params
# 32K params
self.layer = nn.Linear(32, 1000, bias=False)
# 218K params
self.layer1 = nn.Linear(1000, 218, bias=False)
# calculate model size based on precision.
self.pre_calculated_model_size = 1.0 / (32 / precision)

View File

@ -23,3 +23,4 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp

View File

@ -740,6 +740,7 @@ def test_logging_sync_dist_true_ddp(tmpdir):
weights_summary=None,
accelerator="ddp",
gpus=2,
profiler="pytorch"
)
trainer.fit(model)

View File

@ -17,6 +17,7 @@ import pickle
import sys
from argparse import Namespace
from copy import deepcopy
from distutils.version import LooseVersion
from pathlib import Path
from unittest.mock import ANY, call, patch
@ -30,7 +31,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
@ -39,6 +40,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate
@pytest.fixture
def pytorch_profiler(tmpdir):
profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0)
return profiler
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
@ -1421,6 +1428,7 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
('simple', SimpleProfiler),
('Simple', SimpleProfiler),
('advanced', AdvancedProfiler),
('pytorch', PyTorchProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
kwargs = {'profiler': profiler} if profiler is not None else {}
@ -1441,3 +1449,105 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
match=r"Only None, bool, str and subclasses of `BaseProfiler`"
r" are valid values for `Trainer`'s `profiler` parameter. *"):
Trainer(profiler=profiler)
def test_pytorch_profiler_describe(pytorch_profiler):
"""Ensure the profiler won't fail when reporting the summary."""
with pytorch_profiler.profile("test_step"):
pass
# log to stdout and print to file
pytorch_profiler.describe()
data = Path(pytorch_profiler.output_fname).read_text()
assert len(data) > 0
def test_pytorch_profiler_value_errors(pytorch_profiler):
"""Ensure errors are raised where expected."""
action = "test_step"
with pytest.raises(ValueError):
pytorch_profiler.stop(action)
pytorch_profiler.start(action)
pytorch_profiler.stop(action)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@pytest.mark.parametrize("use_output_filename", [False, True])
def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename):
"""Ensure that the profiler can be given to the training and default step are properly recorded. """
if use_output_filename:
output_filename = os.path.join(tmpdir, "profiler.txt")
else:
output_filename = None
profiler = PyTorchProfiler(output_filename=output_filename)
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
profiler=profiler,
accelerator="ddp",
gpus=2
)
trainer.fit(model)
enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0
if enabled:
assert len(profiler.summary()) > 0
assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'}
else:
assert profiler.summary() is None
assert set(profiler.profiled_actions.keys()) == set()
if use_output_filename:
profiler.describe()
data = Path(profiler.output_fname).read_text()
assert len(data) > 0
def test_pytorch_profiler_nested(tmpdir):
"""Ensure that the profiler handles nested context"""
pytorch_profiler = PyTorchProfiler(
profiled_functions=["a", "b", "c"],
use_cuda=False,
output_filename=os.path.join(tmpdir, "profiler.txt"))
with pytorch_profiler.profile("a"):
a = torch.ones(42)
with pytorch_profiler.profile("b"):
b = torch.zeros(42)
with pytorch_profiler.profile("c"):
_ = a + b
pa = pytorch_profiler.profiled_actions
# From PyTorch 1.6.0, more operation are being traced.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
prefix_to_remove = "aten::" if LooseVersion(torch.__version__) >= LooseVersion("1.7.1") else ''
expected_a = ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty']
assert [e.name.replace(prefix_to_remove, '') for e in pa['a']] == expected_a
expected_b = ['zeros', 'empty', 'zero_', 'fill_']
assert [e.name.replace(prefix_to_remove, '') for e in pa['b']] == expected_b
expected_c = ['add', 'empty']
assert [e.name.replace(prefix_to_remove, '') for e in pa['c']] == expected_c
else:
expected_a = ['add']
assert [e.name for e in pa['a']] == expected_a
expected_b = []
assert [e.name for e in pa['b']] == expected_b
expected_c = ['add']
assert [e.name for e in pa['c']] == expected_c