[feat] Add PyTorch Profiler. (#5560)
* add profiler
* add profiler
* update
* resolve flake8
* update doc
* update changelog
* clean doc
* delete prof file
* merge pr codebase
* update
* update doc
* update doc
* update doc
* update on comments
* update docstring
* update docstring
* try
* update test
* Update pytorch_lightning/profiler/__init__.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update pytorch_lightning/profiler/__init__.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* update on comments
* remove old code
* add support for ddp
* resolve flake8
* Update pytorch_lightning/profiler/__init__.py
  Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
* resolve tests
* resolve flake8

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
parent f782230412
commit 5f3372871a
@@ -141,3 +141,4 @@ pytorch\ lightning
test-reports/
wandb
.forked/
*.prof
@@ -59,6 +59,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))

- Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560))


### Changed
@@ -16,7 +16,7 @@ import os
import shutil
import subprocess
from collections import OrderedDict
from typing import Tuple, Dict, Union, List, Any
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch

@@ -182,7 +182,8 @@ class ModelSummary(object):
        self._model = model
        self._mode = mode
        self._layer_summary = self.summarize()
        self._precision_megabytes = (self._model.precision / 8.0) * 1e-6  # 1 byte -> 8 bits
        # 1 byte -> 8 bits
        self._precision_megabytes = (self._model.precision / 8.0) * 1e-6

    @property
    def named_modules(self) -> List[Tuple[str, nn.Module]]:
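# Illustration (not part of the commit): the constant above converts a parameter count
# into (decimal) megabytes. With 32-bit precision each parameter takes 32 / 8 = 4 bytes,
# i.e. 4.0e-6 MB, so a layer with 32_000 parameters contributes roughly 0.128 MB.
precision = 32
precision_megabytes = (precision / 8.0) * 1e-6
assert abs(32_000 * precision_megabytes - 0.128) < 1e-9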
@@ -50,7 +50,7 @@ The profiler's results will be printed at the completion of a training `fit()`.


Advanced Profiling
--------------------
------------------

If you want more information on the functions called during each event, you can use the `AdvancedProfiler`.
This option uses Python's cProfiler_ to provide a report of time spent on *each* function called within your code.
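For instance, a minimal sketch (the output path is illustrative; ``output_filename`` is assumed to be the
argument used by ``AdvancedProfiler`` to save its report):

.. code-block:: python

    profiler = AdvancedProfiler(output_filename="advanced_profiler_report.txt")
    trainer = Trainer(profiler=profiler, max_epochs=1)
    trainer.fit(model)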
@@ -114,13 +114,98 @@ to track and the profiler will record performance for code executed within this

    model = MyModel(profiler)
    trainer = Trainer(profiler=profiler, max_epochs=1)


PyTorch Profiling
-----------------

Autograd includes a profiler that lets you inspect the cost of different operators
inside your model - both on the CPU and GPU.

Find the PyTorch Profiler documentation at [PyTorch Profiler](https://pytorch-lightning.readthedocs.io/en/stable/profiler.html)
.. code-block:: python

    trainer = Trainer(..., profiler="pytorch")

or

.. code-block:: python

    profiler = PyTorchProfiler(...)
    trainer = Trainer(..., profiler=profiler)

This profiler works with PyTorch ``DistributedDataParallel``.
If ``output_filename`` is provided, each rank will save its profiled operations to its own file.


The profiler's results will be printed on the completion of a training `fit()`. This profiler
report can be quite long, so you can also specify an `output_filename` to save the report instead
of logging it to the output in your terminal.

By default, this profiler records only the `training_step_and_backward`, `validation_step` and `test_step` functions.
The output below shows the profiling for the action `training_step_and_backward`.
The user can provide ``PyTorchProfiler(profiled_functions=[...])`` to extend the scope of profiled functions.
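For example, a possible configuration could look like the following sketch (the output path is
illustrative; the listed actions are the profiler's defaults):

.. code-block:: python

    profiler = PyTorchProfiler(
        output_filename="pytorch_profiler_report.txt",
        profiled_functions=["training_step_and_backward", "validation_step", "test_step"],
        sort_by_key="cpu_time_total",
        row_limit=10,
    )
    trainer = Trainer(profiler=profiler)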
.. note:: When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the `SimpleProfiler`.   # noqa E501
.. code-block:: python

    Profiler Report

    Profile stats for: training_step_and_backward
    --------------------- --------------- --------------- --------------- --------------- ---------------
    Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg
    --------------------- --------------- --------------- --------------- --------------- ---------------
    t 62.10% 1.044ms 62.77% 1.055ms 1.055ms
    addmm 32.32% 543.135us 32.69% 549.362us 549.362us
    mse_loss 1.35% 22.657us 3.58% 60.105us 60.105us
    mean 0.22% 3.694us 2.05% 34.523us 34.523us
    div_ 0.64% 10.756us 1.90% 32.001us 16.000us
    ones_like 0.21% 3.461us 0.81% 13.669us 13.669us
    sum_out 0.45% 7.638us 0.74% 12.432us 12.432us
    transpose 0.23% 3.786us 0.68% 11.393us 11.393us
    as_strided 0.60% 10.060us 0.60% 10.060us 3.353us
    to 0.18% 3.059us 0.44% 7.464us 7.464us
    empty_like 0.14% 2.387us 0.41% 6.859us 6.859us
    empty_strided 0.38% 6.351us 0.38% 6.351us 3.175us
    fill_ 0.28% 4.782us 0.33% 5.566us 2.783us
    expand 0.20% 3.336us 0.28% 4.743us 4.743us
    empty 0.27% 4.456us 0.27% 4.456us 2.228us
    copy_ 0.15% 2.526us 0.15% 2.526us 2.526us
    broadcast_tensors 0.15% 2.492us 0.15% 2.492us 2.492us
    size 0.06% 0.967us 0.06% 0.967us 0.484us
    is_complex 0.06% 0.961us 0.06% 0.961us 0.481us
    stride 0.03% 0.517us 0.03% 0.517us 0.517us
    --------------------- --------------- --------------- --------------- --------------- ---------------
    Self CPU time total: 1.681ms

When running with `PyTorchProfiler(emit_nvtx=True)`, you should run nvprof as follows::

    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

To visualize the profiled operations, you can either:

* Use::

    nvvp trace_name.prof

* Use::

    python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'

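As a minimal sketch of the Python side of that workflow (trainer arguments are illustrative):

.. code-block:: python

    profiler = PyTorchProfiler(emit_nvtx=True)
    trainer = Trainer(profiler=profiler, max_epochs=1)
    trainer.fit(model)  # run this command under ``nvprof`` as shown above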
"""
|
||||
|
||||
from pytorch_lightning.profiler.profilers import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
|
||||
from pytorch_lightning.profiler.profilers import (
|
||||
AdvancedProfiler,
|
||||
BaseProfiler,
|
||||
PassThroughProfiler,
|
||||
PyTorchProfiler,
|
||||
SimpleProfiler,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BaseProfiler',
|
||||
'SimpleProfiler',
|
||||
'AdvancedProfiler',
|
||||
'PassThroughProfiler',
|
||||
"PyTorchProfiler",
|
||||
]
|
||||
|
|
|
@@ -15,6 +15,7 @@
"""Profiler to check if there are any bottlenecks in your code."""

import cProfile
import inspect
import io
import os
import pstats

@@ -22,12 +23,16 @@ import time
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class BaseProfiler(ABC):
@@ -95,6 +100,9 @@ class BaseProfiler(ABC):
    def summary(self) -> str:
        """Create profiler summary in text format."""

    def on_train_start(self, local_rank: Optional[int] = None):
        self.local_rank = local_rank


class PassThroughProfiler(BaseProfiler):
    """

@@ -282,3 +290,263 @@ class AdvancedProfiler(BaseProfiler):
        """Close profiler's stream."""
        if self.output_file:
            self.output_file.close()


class PyTorchProfiler(BaseProfiler):

    PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step")
    AVAILABLE_SORT_KEYS = (
        "cpu_time", "cuda_time", "cpu_time_total",
        "cuda_time_total", "cpu_memory_usage", "cuda_memory_usage",
        "self_cpu_memory_usage", "self_cuda_memory_usage", "count"
    )

    def __init__(
        self,
        output_filename: Optional[str] = None,
        enabled: bool = True,
        use_cuda: bool = False,
        record_shapes: bool = False,
        profile_memory: bool = False,
        group_by_input_shapes: bool = False,
        with_stack: bool = False,
        use_kineto: bool = False,
        use_cpu: bool = False,
        emit_nvtx: bool = False,
        export_to_chrome: bool = False,
        path_to_export_trace: Optional[str] = None,
        row_limit: int = 20,
        sort_by_key: Optional[str] = None,
        profiled_functions: Optional[List] = None,
        local_rank: Optional[int] = None,
    ):
"""
|
||||
This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
|
||||
different operators inside your model - both on the CPU and GPU
|
||||
|
||||
Args:
|
||||
|
||||
output_filename: optionally save profile results to file instead of printing
|
||||
to std out when training is finished. When using ``ddp``,
|
||||
each rank will stream the profiled operation to their own file
|
||||
with the extension ``_{rank}.txt``
|
||||
|
||||
enabled: Setting this to False makes this context manager a no-op.
|
||||
|
||||
use_cuda: Enables timing of CUDA events as well using the cudaEvent API.
|
||||
Adds approximately 4us of overhead to each tensor operation.
|
||||
|
||||
record_shapes: If shapes recording is set, information about input dimensions will be collected.
|
||||
|
||||
profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0)
|
||||
|
||||
group_by_input_shapes: Include operator input shapes and group calls by shape.
|
||||
|
||||
with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0)
|
||||
|
||||
use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0)
|
||||
|
||||
use_cpu: use_kineto=True and can be used to lower the overhead
|
||||
for GPU-only profiling (Introduced in PyTorch 1.8.0)
|
||||
|
||||
emit_nvtx: Context manager that makes every autograd operation emit an NVTX range
|
||||
Run::
|
||||
|
||||
nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
|
||||
|
||||
To visualize, you can either use::
|
||||
|
||||
nvvp trace_name.prof
|
||||
torch.autograd.profiler.load_nvprof(path)
|
||||
|
||||
export_to_chrome: Wether to export the sequence of profiled operators for Chrome.
|
||||
It will generate a ``.json`` file which can be read by Chrome.
|
||||
|
||||
path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``.
|
||||
By default, it will be save where the file being is being run.
|
||||
|
||||
row_limit: Limit the number of rows in a table, `0` is a special value that
|
||||
removes the limit completely.
|
||||
|
||||
sort_by_key: Keys to sort out profiled table
|
||||
|
||||
profiled_functions: list of profiled functions which will create a context manager on.
|
||||
Any other will be pass through.
|
||||
|
||||
local_rank: When running in distributed setting, local_rank is used for each process
|
||||
to write to their own file if `output_fname` is provided.
|
||||
"""
|
||||
|
||||
        self.profiled_actions = {}
        self.enabled = enabled
        self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
        self.use_cuda = use_cuda
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
        self.with_stack = with_stack
        self.group_by_input_shapes = group_by_input_shapes and record_shapes
        self.use_kineto = use_kineto
        self.use_cpu = use_cpu
        self.row_limit = row_limit
        self.emit_nvtx = emit_nvtx
        self.export_to_chrome = export_to_chrome
        self.path_to_export_trace = path_to_export_trace

        if export_to_chrome and path_to_export_trace is None:
            rank_zero_warn(
                "The exported trace will be saved locally as `path_to_export_trace` is empty."
                " Note: Each function will generate its own traced file.")

        if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
            raise MisconfigurationException(
                f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}.")

        self.profiled_actions = {}
        self.context_names = {}
        self.running_stack = []
        self.profiler = None

        self.output_fname = output_filename
        self.output_file = None
        if local_rank is not None:
            self.on_train_start(local_rank=local_rank)
            self.on_train_start = super().on_train_start

    def on_train_start(self, local_rank: Optional[int] = None):
        self.local_rank = local_rank

        # when logging to `log.info`, only perform profiling on rank 0
        if local_rank != 0 and self.output_fname is None:
            self.wrap_functions_into_rank_zero_only()

        if self.output_fname:
            if local_rank is not None:
                if '.txt' not in self.output_fname:
                    raise MisconfigurationException("Log file should be a .txt file.")

                self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt")

            fs = get_filesystem(self.output_fname)
            self.output_file = fs.open(self.output_fname, "w")

        streaming_out = [self.output_file.write] if self.output_file else [log.info]
        super().__init__(output_streams=streaming_out)

    def wrap_functions_into_rank_zero_only(self):
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)
        self.describe = rank_zero_only(self.describe)
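    # Illustration (not part of the commit): the per-rank file naming implemented above.
    # With an output file, each ddp rank streams its report to its own ``_{rank}.txt`` file;
    # without one, profiling stays enabled only on rank 0. Using the class defined above:
    #
    #     profiler = PyTorchProfiler(output_filename="profiler.txt")  # illustrative path
    #     profiler.on_train_start(local_rank=1)                       # opens "profiler_1.txt"
    #     assert profiler.output_fname == "profiler_1.txt"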

    def start(self, action_name: str) -> None:
        if action_name not in self.profiled_functions:
            return

        if len(self.running_stack) > 0:
            self._stop(self.running_stack[-1])
        self.running_stack.append(action_name)

        self.context_names[action_name] = "/".join(self.running_stack)

        self._start(action_name)

    def _start(self, action_name: str) -> None:
        if self.emit_nvtx:
            self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
            self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
        else:
            self._create_profiler(action_name, torch.autograd.profiler.profile)

    def _create_profiler(self, action_name, profiler, enter=True):
        init_args = inspect.signature(profiler.__init__).parameters
        profiler_args = {
            k: v for k, v in vars(self).items() if k in init_args
        }
        pr = profiler(**profiler_args)
        if enter:
            pr = pr.__enter__()
        self.profiler = pr
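    # Illustration (not part of the commit): ``_create_profiler`` forwards only those attributes
    # of ``self`` whose names match parameters of the wrapped profiler's ``__init__``.
    # A standalone sketch of the same filtering (the config dict is illustrative):
    #
    #     init_args = inspect.signature(torch.autograd.profiler.profile.__init__).parameters
    #     config = {"use_cuda": False, "record_shapes": True, "row_limit": 20}
    #     profiler_args = {k: v for k, v in config.items() if k in init_args}
    #     # ``row_limit`` is dropped: ``torch.autograd.profiler.profile`` does not accept it,
    #     # it is only used later when rendering the summary table.
    #     prof = torch.autograd.profiler.profile(**profiler_args)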

    def _stop(self, action_name: str) -> None:
        if self.profiler is None:
            return

        self.profiler.__exit__(
            exc_type=None,
            exc_val=None,
            exc_tb=None
        )

        function_events = self.profiler.function_events
        self.profiler = None
        for name in self.running_stack:
            if name not in self.profiled_actions:
                self.profiled_actions[name] = function_events
            else:
                self.profiled_actions[name] += function_events

    def stop(self, action_name: str) -> None:
        if action_name not in self.profiled_functions:
            return

        if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
            raise ValueError(  # pragma: no-cover
                f"Attempting to stop recording an action ({action_name}) which was never started."
            )
        self._stop(action_name)
        self.running_stack.pop()
        # restore running profiler
        if len(self.running_stack) > 0:
            self._start(self.running_stack[-1])
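    # Illustration (not part of the commit): ``start``/``stop`` keep a single torch profiler
    # active at a time; entering a nested action stops the parent's profiler and restarts it
    # when the nested action stops, while recorded events are credited to every action on the
    # stack. Action names below are illustrative:
    #
    #     profiler = PyTorchProfiler(profiled_functions=["outer", "inner"], local_rank=0)
    #     with profiler.profile("outer"):
    #         x = torch.ones(8)
    #         with profiler.profile("inner"):
    #             y = x * 2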

    def summary(self) -> str:
        recorded_stats = {}
        output_string = ''
        local_rank = '0' if self.local_rank is None else self.local_rank

        if not self.enabled:
            return output_string

        for action_name, function_events in self.profiled_actions.items():

            # next line is a workaround for a pytorch issue (fixed on master, still present
            # on 1.7). Without it the code fails with `AssertionError: There is already a CPU
            # parent event for detach`
            function_events.populate_cpu_children = lambda: None

            if self.export_to_chrome:
                filename = f"{action_name}_{local_rank}_trace.json"
                path_to_trace = filename if self.path_to_export_trace is None \
                    else os.path.join(self.path_to_export_trace, filename)
                function_events.export_chrome_trace(path_to_trace)

            if self.emit_nvtx:
                return output_string

            else:
                table = function_events.key_averages(
                    group_by_input_shapes=self.group_by_input_shapes).table(
                        sort_by=self.sort_by_key,
                        row_limit=self.row_limit)
                recorded_stats[action_name] = table

        # log to standard out
        output_string = f"{os.linesep}Profiler Report{os.linesep}"
        for action, stats in recorded_stats.items():
            output_string += (
                f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}"
            )

        return output_string

    def describe(self):
        """Logs a profile report after the conclusion of the training run."""
        super().describe()
        if self.output_file:
            self.output_file.flush()

    def __del__(self):
        """Close profiler's stream."""
        if self.output_file:
            self.output_file.close()
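# Illustration (not part of the commit): with ``export_to_chrome=True`` the ``summary()`` above
# writes one trace per profiled action, named "{action_name}_{local_rank}_trace.json" inside
# ``path_to_export_trace`` (or the working directory), which can be opened in chrome://tracing.
# The directory path below is illustrative:
#
#     profiler = PyTorchProfiler(export_to_chrome=True, path_to_export_trace="./traces")
#     trainer = Trainer(profiler=profiler, max_epochs=1)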
@@ -14,13 +14,20 @@

from typing import Union

from pytorch_lightning.profiler import AdvancedProfiler, BaseProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler import (
    AdvancedProfiler,
    BaseProfiler,
    PassThroughProfiler,
    PyTorchProfiler,
    SimpleProfiler,
)
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

PROFILERS = {
    "simple": SimpleProfiler,
    "advanced": AdvancedProfiler,
    "pytorch": PyTorchProfiler
}


@@ -51,3 +58,7 @@ class ProfilerConnector:
                raise ValueError("When passing string value for the `profiler` parameter of"
                                 " `Trainer`, it can only be 'simple' or 'advanced'")
        self.trainer.profiler = profiler or PassThroughProfiler()

    def on_train_start(self, trainer):
        local_rank = trainer.local_rank if trainer.world_size > 1 else None
        self.trainer.profiler.on_train_start(local_rank)
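# Illustration (not part of the commit): string values accepted by ``Trainer(profiler=...)`` are
# resolved through the ``PROFILERS`` mapping above, so the following two calls are roughly
# equivalent (assuming a default-constructed profiler):
#
#     trainer = Trainer(profiler="pytorch")
#     trainer = Trainer(profiler=PyTorchProfiler())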
@@ -164,10 +164,12 @@ class EvaluationLoop(object):
        # run actual test step
        if self.testing:
            model_ref._current_fx_name = "test_step"
            output = self.trainer.accelerator_backend.test_step(args)
            with self.trainer.profiler.profile("test_step"):
                output = self.trainer.accelerator_backend.test_step(args)
        else:
            model_ref._current_fx_name = "validation_step"
            output = self.trainer.accelerator_backend.validation_step(args)
            with self.trainer.profiler.profile("validation_step"):
                output = self.trainer.accelerator_backend.validation_step(args)

        # capture any logged information
        self.trainer.logger_connector.cache_logged_metrics()
@@ -111,6 +111,9 @@ class TrainLoop:
        # hook
        self.trainer.call_hook("on_train_start")

        # provide rank to profiler
        self.trainer.profile_connector.on_train_start(self.trainer)

    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
        # bind logger and other properties
        self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -339,7 +342,8 @@ class TrainLoop:
            # manually capture logged metrics
            model_ref._current_fx_name = 'training_step'
            model_ref._results = Result()
            training_step_output = self.trainer.accelerator_backend.training_step(args)
            with self.trainer.profiler.profile("training_step"):
                training_step_output = self.trainer.accelerator_backend.training_step(args)
            self.trainer.logger_connector.cache_logged_metrics()

            self._check_training_step_output(training_step_output)
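# Illustration (not part of the commit): both loops now wrap the step call in the profiler's
# context manager, e.g.
#
#     with self.trainer.profiler.profile("training_step"):
#         training_step_output = self.trainer.accelerator_backend.training_step(args)
#
# With the default ``PassThroughProfiler`` this is a no-op; with ``PyTorchProfiler`` an action
# is only recorded if its name appears in ``profiled_functions`` (by default
# ``training_step_and_backward``, ``validation_step`` and ``test_step``).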
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from abc import ABC
from collections import OrderedDict
@@ -40,8 +40,10 @@ class PreCalculatedModel(BoringModel):

    def __init__(self, precision: int = 32):
        super().__init__()
        self.layer = nn.Linear(32, 1000, bias=False)  # 32K params
        self.layer1 = nn.Linear(1000, 218, bias=False)  # 218K params
        # 32K params
        self.layer = nn.Linear(32, 1000, bias=False)
        # 218K params
        self.layer1 = nn.Linear(1000, 218, bias=False)

        # calculate model size based on precision.
        self.pre_calculated_model_size = 1.0 / (32 / precision)
@@ -23,3 +23,4 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp
@@ -740,6 +740,7 @@ def test_logging_sync_dist_true_ddp(tmpdir):
        weights_summary=None,
        accelerator="ddp",
        gpus=2,
        profiler="pytorch"
    )
    trainer.fit(model)
@@ -17,6 +17,7 @@ import pickle
import sys
from argparse import Namespace
from copy import deepcopy
from distutils.version import LooseVersion
from pathlib import Path
from unittest.mock import ANY, call, patch

@@ -30,7 +31,7 @@ from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE

@@ -39,6 +40,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate


@pytest.fixture
def pytorch_profiler(tmpdir):
    profiler = PyTorchProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"), local_rank=0)
    return profiler


@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
@@ -1421,6 +1428,7 @@ def test_log_every_n_steps(log_metrics_mock, tmpdir, train_batches, max_steps, l
    ('simple', SimpleProfiler),
    ('Simple', SimpleProfiler),
    ('advanced', AdvancedProfiler),
    ('pytorch', PyTorchProfiler),
])
def test_trainer_profiler_correct_args(profiler, expected):
    kwargs = {'profiler': profiler} if profiler is not None else {}
@@ -1441,3 +1449,105 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
            match=r"Only None, bool, str and subclasses of `BaseProfiler`"
                  r" are valid values for `Trainer`'s `profiler` parameter. *"):
        Trainer(profiler=profiler)


def test_pytorch_profiler_describe(pytorch_profiler):
    """Ensure the profiler won't fail when reporting the summary."""
    with pytorch_profiler.profile("test_step"):
        pass

    # log to stdout and print to file
    pytorch_profiler.describe()
    data = Path(pytorch_profiler.output_fname).read_text()
    assert len(data) > 0


def test_pytorch_profiler_value_errors(pytorch_profiler):
    """Ensure errors are raised where expected."""

    action = "test_step"
    with pytest.raises(ValueError):
        pytorch_profiler.stop(action)

    pytorch_profiler.start(action)
    pytorch_profiler.stop(action)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@pytest.mark.parametrize("use_output_filename", [False, True])
def test_pytorch_profiler_trainer_ddp(tmpdir, use_output_filename):
    """Ensure that the profiler can be given to the trainer and that the default steps are properly recorded."""

    if use_output_filename:
        output_filename = os.path.join(tmpdir, "profiler.txt")
    else:
        output_filename = None

    profiler = PyTorchProfiler(output_filename=output_filename)

    model = BoringModel()
    trainer = Trainer(
        fast_dev_run=True,
        profiler=profiler,
        accelerator="ddp",
        gpus=2,
    )
    trainer.fit(model)

    enabled = use_output_filename or not use_output_filename and profiler.local_rank == 0

    if enabled:
        assert len(profiler.summary()) > 0
        assert set(profiler.profiled_actions.keys()) == {'training_step_and_backward', 'validation_step'}
    else:
        assert profiler.summary() is None
        assert set(profiler.profiled_actions.keys()) == set()

    if use_output_filename:
        profiler.describe()
        data = Path(profiler.output_fname).read_text()
        assert len(data) > 0

def test_pytorch_profiler_nested(tmpdir):
    """Ensure that the profiler handles nested contexts."""

    pytorch_profiler = PyTorchProfiler(
        profiled_functions=["a", "b", "c"],
        use_cuda=False,
        output_filename=os.path.join(tmpdir, "profiler.txt"))

    with pytorch_profiler.profile("a"):
        a = torch.ones(42)
        with pytorch_profiler.profile("b"):
            b = torch.zeros(42)
        with pytorch_profiler.profile("c"):
            _ = a + b

    pa = pytorch_profiler.profiled_actions

    # From PyTorch 1.6.0, more operations are being traced.
    if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
        prefix_to_remove = "aten::" if LooseVersion(torch.__version__) >= LooseVersion("1.7.1") else ''

        expected_a = ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty']
        assert [e.name.replace(prefix_to_remove, '') for e in pa['a']] == expected_a

        expected_b = ['zeros', 'empty', 'zero_', 'fill_']
        assert [e.name.replace(prefix_to_remove, '') for e in pa['b']] == expected_b

        expected_c = ['add', 'empty']
        assert [e.name.replace(prefix_to_remove, '') for e in pa['c']] == expected_c

    else:
        expected_a = ['add']
        assert [e.name for e in pa['a']] == expected_a

        expected_b = []
        assert [e.name for e in pa['b']] == expected_b

        expected_c = ['add']
        assert [e.name for e in pa['c']] == expected_c