Refactor PyTorch profiler 4/5 (#6349)

Co-authored-by: thomas chaton <thomas@grid.ai>
Carlos Mocholí 2021-03-23 18:13:29 +01:00 committed by GitHub
parent 3cf0c3117a
commit 51b10f78f4
11 changed files with 376 additions and 218 deletions

View File

@ -52,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `AbstractProfiler` interface ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
- Added support for including module names for forward in the autograd trace of `PyTorchProfiler` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
@ -72,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed profilers to save separate report files per state and rank ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
### Deprecated
- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
@ -83,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Profiler(output_filename)` in favor of `dirpath` and `filename` ([#6621](https://github.com/PyTorchLightning/pytorch-lightning/pull/6621))
- Deprecated `PytorchProfiler(profiled_functions)` in favor of `record_functions` ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),

View File

@ -126,7 +126,7 @@ class BaseProfiler(AbstractProfiler):
filename += f"{self._stage}-"
filename += str(self.filename)
if self._local_rank is not None:
filename += f"-{self.local_rank}"
filename += f"-{self._local_rank}"
filename += ".txt"
return filename
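
A minimal standalone sketch of the naming scheme implemented above; `prepare` is a hypothetical free-function equivalent of `_prepare_filename`, and the expected names match the assertions in the tests further down:

    def prepare(stage, filename, local_rank):
        # mirrors _prepare_filename: optional stage prefix, optional rank suffix
        name = ""
        if stage is not None:
            name += f"{stage}-"
        name += str(filename)
        if local_rank is not None:
            name += f"-{local_rank}"
        return name + ".txt"

    assert prepare("fit", "profiler", None) == "fit-profiler.txt"
    assert prepare("fit", "profiler", 0) == "fit-profiler-0.txt"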
@ -134,8 +134,7 @@ class BaseProfiler(AbstractProfiler):
if self._write_stream is not None:
return
if self.filename:
dirpath = self.dirpath or self._log_dir
filepath = os.path.join(dirpath, self._prepare_filename())
filepath = os.path.join(self.dirpath, self._prepare_filename())
fs = get_filesystem(filepath)
file = fs.open(filepath, "a")
self._output_file = file
@ -175,8 +174,7 @@ class BaseProfiler(AbstractProfiler):
self._stage = stage
self._local_rank = local_rank
self._log_dir = log_dir
if self.dirpath is None:
self.dirpath = self._log_dir
self.dirpath = self.dirpath or log_dir
def teardown(self, stage: Optional[str] = None) -> None:
"""
@ -202,8 +200,8 @@ class BaseProfiler(AbstractProfiler):
raise NotImplementedError
@property
def local_rank(self):
return '0' if self._local_rank is None else self._local_rank
def local_rank(self) -> int:
return 0 if self._local_rank is None else self._local_rank
class PassThroughProfiler(BaseProfiler):

View File

@ -12,25 +12,92 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""
import inspect
import logging
import os
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING, Union
import torch
from torch import nn, Tensor
from torch.autograd.profiler import record_function
from pytorch_lightning.profiler.profilers import BaseProfiler
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if TYPE_CHECKING:
from torch.autograd.profiler import EventList
from torch.utils.hooks import RemovableHandle
from pytorch_lightning.core.lightning import LightningModule
log = logging.getLogger(__name__)
_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
class RegisterRecordFunction:
"""
While profiling autograd operations, this class will add labels for module names around the forward function.
The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows:
Example::
from pytorch_lightning.profilers import PyTorchProfiler
profiler = PyTorchProfiler(record_module_names=False)
Trainer(profiler=profiler)
It can be used outside of Lightning as follows:
Example::
from pytorch_lightning import Trainer, seed_everything
with RegisterRecordFunction(model):
out = model(batch)
"""
def __init__(self, model: nn.Module) -> None:
self._model = model
self._records: Dict[str, record_function] = {}
self._handles: Dict[str, List['RemovableHandle']] = {}
def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor:
record = record_function(record_name)
record.__enter__()
self._records[record_name] = record
return input
def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor:
self._records[record_name].__exit__(None, None, None)
return output
def __enter__(self) -> None:
for module_name, module in self._model.named_modules():
if module_name:
full_name = f"{type(module).__module__}.{type(module).__name__}"
record_name = f"{full_name}: {module_name}"
pre_forward_handle = module.register_forward_pre_hook(
partial(self._start_recording_forward, record_name=record_name)
)
post_forward_handle = module.register_forward_hook(
partial(self._stop_recording_forward, record_name=record_name)
)
self._handles[module_name] = [pre_forward_handle, post_forward_handle]
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
for handles in self._handles.values():
for h in handles:
h.remove()
self._handles = {}
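
A runnable sketch of the standalone usage promised in the docstring above, assuming only `torch` and this module; event names follow the `"<qualified class>: <module name>"` scheme built in `__enter__`:

    import torch
    from torch import nn

    from pytorch_lightning.profiler.pytorch import RegisterRecordFunction

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    batch = torch.rand(1, 8)
    with torch.autograd.profiler.profile() as prof:
        with RegisterRecordFunction(model):
            model(batch)
    # expect entries such as "torch.nn.modules.linear.Linear: 0"
    print(sorted({e.name for e in prof.function_events}))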
class PyTorchProfiler(BaseProfiler):
PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step")
RECORD_FUNCTIONS = (
"training_step_and_backward", "training_step", "backward", "validation_step", "test_step", "predict_step"
)
AVAILABLE_SORT_KEYS = (
"cpu_time",
"cuda_time",
@ -42,27 +109,24 @@ class PyTorchProfiler(BaseProfiler):
"self_cuda_memory_usage",
"count",
)
START_RECORD_FUNCTIONS = ('on_train_start', 'on_validation_start', 'on_test_start', 'on_predict_start')
def __init__(
self,
dirpath: Optional[Union[str, Path]] = None,
filename: Optional[str] = None,
enabled: bool = True,
use_cuda: bool = False,
record_shapes: bool = False,
profile_memory: bool = False,
group_by_input_shapes: bool = False,
with_stack: bool = False,
use_kineto: bool = False,
use_cpu: bool = True,
emit_nvtx: bool = False,
export_to_chrome: bool = False,
path_to_export_trace: str = None,
export_to_chrome: bool = True,
path_to_export_trace: Optional[str] = None,
row_limit: int = 20,
sort_by_key: Optional[str] = None,
record_functions: List[str] = None,
record_module_names: bool = True,
profiled_functions: Optional[List] = None,
output_filename: Optional[str] = None,
):
**profiler_kwargs: Any,
) -> None:
"""
This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
different operators inside your model - both on the CPU and GPU
@ -75,24 +139,8 @@ class PyTorchProfiler(BaseProfiler):
filename: If present, filename where the profiler results will be saved instead of printing to stdout.
The ``.txt`` extension will be used automatically.
enabled: Setting this to False makes this context manager a no-op.
use_cuda: Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.
record_shapes: If shapes recording is set, information about input dimensions will be collected.
profile_memory: Whether to report memory usage, default: True (Introduced in PyTorch 1.6.0)
group_by_input_shapes: Include operator input shapes and group calls by shape.
with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0)
use_kineto: experimental support for Kineto profiler (Introduced in PyTorch 1.8.0)
use_cpu: use_kineto=True and can be used to lower the overhead
for GPU-only profiling (Introduced in PyTorch 1.8.0)
emit_nvtx: Context manager that makes every autograd operation emit an NVTX range
Run::
@ -103,164 +151,189 @@ class PyTorchProfiler(BaseProfiler):
nvvp trace_name.prof
torch.autograd.profiler.load_nvprof(path)
export_to_chrome: Wether to export the sequence of profiled operators for Chrome.
export_to_chrome: Whether to export the sequence of profiled operators for Chrome.
It will generate a ``.json`` file which can be read by Chrome.
path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``.
By default, it will be saved where the file is being run.
row_limit: Limit the number of rows in a table, `0` is a special value that
row_limit: Limit the number of rows in a table, ``-1`` is a special value that
removes the limit completely.
sort_by_key: Keys to sort out profiled table
sort_by_key: Attribute used to sort entries. By default
they are printed in the same order as they were registered.
Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
profiled_functions: list of profiled functions which will create a context manager on.
record_functions: list of profiled functions for which a context manager will be created.
Any other action will be passed through.
record_module_names: Whether to add module names while recording autograd operations.
profiler_kwargs: Keyword arguments for the PyTorch profiler. These depend on your PyTorch version.
Raises:
MisconfigurationException:
If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``.
ValueError:
If you attempt to stop recording an action which was never started.
"""
super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
self.profiled_actions = {}
self.enabled = enabled
self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
self.use_cuda = use_cuda
self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
self.with_stack = with_stack
self.group_by_input_shapes = group_by_input_shapes and record_shapes
self.use_kineto = use_kineto
self.use_cpu = use_cpu
self.row_limit = row_limit
self.emit_nvtx = emit_nvtx
self.export_to_chrome = export_to_chrome
self.path_to_export_trace = path_to_export_trace
record_functions = self.__deprecation_check(profiled_functions, record_functions)
if export_to_chrome and path_to_export_trace is None:
self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False)
self._emit_nvtx = emit_nvtx
self._export_to_chrome = export_to_chrome
self._path_to_export_trace = path_to_export_trace
self._row_limit = row_limit
self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
self._record_functions_start = set(record_functions + list(self.START_RECORD_FUNCTIONS))
self._record_functions = set(record_functions + list(self.RECORD_FUNCTIONS))
self._record_module_names = record_module_names
self._profiler_kwargs = profiler_kwargs
self.profiler: Optional[_PROFILER] = None
self.function_events: Optional['EventList'] = None
self._lightning_module: Optional['LightningModule'] = None # set by ProfilerConnector
self._register: Optional[RegisterRecordFunction] = None
self._parent_profiler: Optional[_PROFILER] = None
self._recording_map: Dict[str, record_function] = {}
if self._export_to_chrome and self._path_to_export_trace is None:
rank_zero_warn(
"The exported trace would be save locally as `path_to_export_trace` is empty."
"The exported trace would be saved locally as `path_to_export_trace` is None."
" Note: Each functions will generate its own traced file."
)
if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
raise MisconfigurationException(
f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
)
self.profiled_actions = {}
self.context_names = {}
self.running_stack = []
self.profiler = None
def __deprecation_check(
self,
profiled_functions: Optional[List[str]],
record_functions: Optional[List[str]],
) -> List[str]:
if record_functions is None:
record_functions = []
super().__init__(dirpath=dirpath, filename=filename, output_filename=output_filename)
if profiled_functions is not None:
rank_zero_warn(
"`PyTorchProfiler.profiled_functions` has been renamed to"
" `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning
)
if not record_functions:
record_functions += profiled_functions
else:
raise MisconfigurationException(
"You set `PytorchProfiler.profiled_functions` and `PyTorchProfiler.record_functions`."
" Please use only the later."
)
return record_functions
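
For example, the shim above gives the following behavior (asserted by the deprecation tests further down):

    # warns: "`PyTorchProfiler.profiled_functions` has been renamed to `record_functions` in v1.3"
    PyTorchProfiler(profiled_functions=["training_step"])

    # raises MisconfigurationException: the two arguments are mutually exclusive
    PyTorchProfiler(profiled_functions=["a"], record_functions=["b"])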
def setup(
self,
stage: Optional[str] = None,
local_rank: Optional[int] = None,
log_dir: Optional[str] = None
self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None
) -> None:
super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
# if the user didn't provide `path_to_export_trace`,
# set it to the TensorBoardLogger log_dir if it exists
if self.path_to_export_trace is None:
self.path_to_export_trace = log_dir
if self._path_to_export_trace is None:
self._path_to_export_trace = log_dir
def start(self, action_name: str) -> None:
if action_name not in self.profiled_functions:
return
if self.profiler is None and action_name in self._record_functions_start:
if len(self.running_stack) > 0:
self._stop(self.running_stack[-1])
self.running_stack.append(action_name)
# close profiler if it is already opened. might happen if 2 profilers
# are created and the first one did not call `describe`
try:
torch.autograd._disable_profiler() # noqa
except (AttributeError, RuntimeError):
pass
self.context_names[action_name] = "/".join(self.running_stack)
self._create_profilers()
self._start(action_name)
self.profiler.__enter__()
if self._parent_profiler is not None:
self._parent_profiler.__enter__()
if self._register is not None:
self._register.__enter__()
def _start(self, action_name: str) -> None:
if self.emit_nvtx:
self._parent_profiler = self._create_profiler(action_name, torch.cuda.profiler.profile, enter=True)
self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
else:
self._create_profiler(action_name, torch.autograd.profiler.profile)
def _create_profiler(self, action_name, profiler, enter=True):
init_args = inspect.signature(profiler.__init__).parameters
profiler_args = {k: v for k, v in vars(self).items() if k in init_args}
pr = profiler(**profiler_args)
if enter:
out_pr = pr.__enter__()
if out_pr is not None:
pr = out_pr
self.profiler = pr
return self.profiler
def _stop(self, action_name: str) -> None:
if self.profiler is None:
return
self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None)
if isinstance(self.profiler, torch.autograd.profiler.emit_nvtx):
# when running ``emit_nvtx``, PyTorch requires 2 context manager.
# The parent_profiler is being closed too.
self._parent_profiler.__exit__(None, None, None)
self._parent_profiler = None
return
function_events = self.profiler.function_events
self.profiler = None
for name in self.running_stack:
if name not in self.profiled_actions:
self.profiled_actions[name] = function_events
else:
self.profiled_actions[name] += function_events
if (
self.profiler is not None and action_name in self._record_functions
and action_name not in self._recording_map
):
recording = record_function(action_name)
recording.__enter__()
self._recording_map[action_name] = recording
def stop(self, action_name: str) -> None:
if action_name not in self.profiled_functions:
return
if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
self._stop(action_name)
self.running_stack.pop()
# restore running profiler
if len(self.running_stack) > 0:
self._start(self.running_stack[-1])
if action_name in self._recording_map:
self._recording_map[action_name].__exit__(None, None, None)
del self._recording_map[action_name]
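
The `start`/`stop` pair above amounts to keeping a `record_function` context open between the two calls; the bookkeeping pattern in isolation (a sketch; `start_action`/`stop_action` are illustrative names):

    from torch.autograd.profiler import record_function

    _open = {}

    def start_action(name: str) -> None:
        rec = record_function(name)
        rec.__enter__()  # opens the labelled range
        _open[name] = rec

    def stop_action(name: str) -> None:
        _open.pop(name).__exit__(None, None, None)  # closes it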
def summary(self) -> str:
recorded_stats = {}
output_string = ''
if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx:
return ""
if not self.enabled:
return output_string
self._delete_profilers()
for action_name, function_events in self.profiled_actions.items():
if not self.function_events:
return ""
# next line is a workaround for a pytorch issue (fixed on master, still present
# on 1.7). Without it the code fails with `AssertionError: There is already a CPU
# parent event for detach`
function_events.populate_cpu_children = lambda: None
if self._export_to_chrome:
filename = f"{self.local_rank}_trace.json"
path_to_trace = (
filename if self._path_to_export_trace is None else os.path.join(self._path_to_export_trace, filename)
)
self.function_events.export_chrome_trace(path_to_trace)
if self.export_to_chrome:
filename = f"{action_name}_{self.local_rank}_trace.json"
path_to_trace = filename if self.path_to_export_trace is None \
else os.path.join(self.path_to_export_trace, filename)
function_events.export_chrome_trace(path_to_trace)
data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes)
table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit)
if self.emit_nvtx:
return output_string
else:
data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes)
table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit)
recorded_stats[action_name] = table
recorded_stats = {"records": table}
return self._stats_to_str(recorded_stats)
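
Outside of Lightning, the same summary can be produced from a raw autograd profile with public `torch` APIs (a sketch):

    import torch

    with torch.autograd.profiler.profile() as prof:
        torch.ones(2, 2) + torch.ones(2, 2)

    prof.export_chrome_trace("trace.json")  # what the `_export_to_chrome` branch does per rank
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))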
def _create_profilers(self) -> None:
if self._emit_nvtx:
self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
else:
self._parent_profiler = None
self.profiler = self._create_profiler(torch.autograd.profiler.profile)
if self._record_module_names and self._lightning_module is not None:
self._register = RegisterRecordFunction(self._lightning_module)
def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER:
init_parameters = inspect.signature(profiler.__init__).parameters
kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters}
return profiler(**kwargs)
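
`_create_profiler` only forwards the keyword arguments each profiler class actually declares; the filtering idiom on its own:

    import inspect

    import torch

    def filter_kwargs(cls, **kwargs):
        # keep only the keys that `cls.__init__` accepts
        accepted = inspect.signature(cls.__init__).parameters
        return {k: v for k, v in kwargs.items() if k in accepted}

    print(filter_kwargs(torch.autograd.profiler.profile, use_cuda=False, bogus=1))
    # -> {'use_cuda': False}; unknown keys such as `bogus` are dropped silently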
def _cache_functions_events(self):
if not self._emit_nvtx:
self.function_events = self.profiler.function_events
def _delete_profilers(self) -> None:
if self.profiler is not None:
self.profiler.__exit__(None, None, None)
self._cache_functions_events()
self.profiler = None
if self._parent_profiler is not None:
self._parent_profiler.__exit__(None, None, None)
self._parent_profiler = None
if self._register is not None:
self._register.__exit__(None, None, None)
self._register = None
def teardown(self, stage: Optional[str] = None) -> None:
self._delete_profilers()
for k in self._recording_map:
self.stop(k)
self._recording_map = {}
super().teardown(stage=stage)
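
Putting the refactored surface together, a typical configuration could look like this (a sketch; `record_shapes` travels through `profiler_kwargs` to the underlying profiler):

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import PyTorchProfiler

    profiler = PyTorchProfiler(
        dirpath=".",
        filename="profiler",
        record_module_names=True,  # adds per-module labels via RegisterRecordFunction
        record_shapes=True,        # forwarded to torch.autograd.profiler.profile
    )
    trainer = Trainer(profiler=profiler)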

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from typing import Union
from weakref import proxy
from pytorch_lightning.profiler import (
AdvancedProfiler,
@ -57,4 +57,5 @@ class ProfilerConnector:
def setup(self) -> None:
trainer = self.trainer
local_rank = trainer.local_rank if trainer.world_size > 1 else None
trainer.profiler.lightning_module = proxy(trainer.lightning_module)
trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
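
`proxy` gives the profiler attribute access to the `LightningModule` without holding a strong reference to it; a minimal illustration of the semantics (the motivation stated here is an assumption):

    import weakref

    class Owner:
        pass

    obj = Owner()
    view = weakref.proxy(obj)  # behaves like `obj` but does not keep it alive
    del obj
    try:
        view.__class__
    except ReferenceError:
        print("the proxy's target has been garbage collected")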

View File

@ -44,6 +44,8 @@ class PredictLoop(object):
model_ref.on_predict_model_eval()
def setup(self, model, max_batches, dataloaders):
self.trainer.call_hook("on_predict_start")
# copy properties for forward overrides
self.trainer.model_connector.copy_trainer_model_properties(model)
@ -86,6 +88,8 @@ class PredictLoop(object):
return
def on_predict_epoch_end(self):
self.trainer.profiler.describe()
self.trainer._progress_bar_callback.on_predict_end(self.trainer, self.trainer.lightning_module)
results = self._predictions

View File

@ -743,7 +743,7 @@ class TrainLoop:
# backward pass
if result is not None:
with self.trainer.profiler.profile("model_backward"):
with self.trainer.profiler.profile("backward"):
self.backward(result, optimizer, opt_idx)
# hook - call this hook only

View File

@ -68,6 +68,7 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')

View File

@ -47,6 +47,7 @@ def test_model_torch_save_ddp_cpu(tmpdir):
max_epochs=num_epochs,
accelerator="ddp_cpu",
num_processes=2,
logger=False,
)
temp_path = os.path.join(tmpdir, 'temp.pt')
trainer.fit(model)

View File

@ -81,6 +81,11 @@ def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
trainer.save_checkpoint(filepath)
def test_v1_5_0_legacy_profiler_argument():
with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"):
PyTorchProfiler(profiled_functions=[])
def test_v1_5_0_running_sanity_check():
trainer = Trainer()
with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):

View File

@ -13,6 +13,7 @@
# limitations under the License.
import logging
import os
import platform
import time
from copy import deepcopy
from distutils.version import LooseVersion
@ -24,6 +25,9 @@ import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
from pytorch_lightning.profiler.pytorch import RegisterRecordFunction
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@ -126,10 +130,10 @@ def test_simple_profiler_log_dir(tmpdir):
)
trainer.fit(model)
expected = profiler.dirpath
expected = tmpdir / "lightning_logs" / "version_0"
assert trainer.log_dir == expected
assert profiler._log_dir == trainer.log_dir
assert Path(os.path.join(profiler.dirpath, "fit-profiler.txt")).exists()
assert expected.join("fit-profiler.txt").exists()
@RunIf(skip_windows=True)
@ -264,8 +268,8 @@ def pytorch_profiler(tmpdir):
def test_pytorch_profiler_describe(pytorch_profiler):
"""Ensure the profiler won't fail when reporting the summary."""
with pytorch_profiler.profile("test_step"):
pass
with pytorch_profiler.profile("on_test_start"):
torch.tensor(0)
# log to stdout and print to file
pytorch_profiler.describe()
@ -274,15 +278,10 @@ def test_pytorch_profiler_describe(pytorch_profiler):
assert len(data) > 0
def test_pytorch_profiler_value_errors(pytorch_profiler):
def test_pytorch_profiler_raises(pytorch_profiler):
"""Ensure errors are raised where expected."""
action = "test_step"
with pytest.raises(ValueError):
pytorch_profiler.stop(action)
pytorch_profiler.start(action)
pytorch_profiler.stop(action)
with pytest.raises(MisconfigurationException, match="profiled_functions` and `PyTorchProfiler.record"):
PyTorchProfiler(profiled_functions=["a"], record_functions=["b"])
@RunIf(min_torch="1.6.0")
@ -299,9 +298,8 @@ def test_advanced_profiler_cprofile_deepcopy(tmpdir):
@RunIf(min_gpus=2, special=True)
def test_pytorch_profiler_trainer_ddp(tmpdir):
def test_pytorch_profiler_trainer_ddp(tmpdir, pytorch_profiler):
"""Ensure that the profiler can be given to the training and default step are properly recorded. """
pytorch_profiler = PyTorchProfiler(dirpath=None, filename="profiler")
model = BoringModel()
trainer = Trainer(
max_epochs=1,
@ -314,17 +312,68 @@ def test_pytorch_profiler_trainer_ddp(tmpdir):
)
trainer.fit(model)
assert len(pytorch_profiler.summary()) > 0
assert set(pytorch_profiler.profiled_actions) == {'training_step_and_backward', 'validation_step'}
expected = ('validation_step', 'training_step_and_backward', 'training_step', 'backward')
for name in expected:
assert sum(e.name == name for e in pytorch_profiler.function_events)
files = sorted(f for f in os.listdir(pytorch_profiler.dirpath) if "fit" in f)
rank = int(os.getenv("LOCAL_RANK", "0"))
expected = f"fit-profiler-{rank}.txt"
assert files[rank] == expected
files = set(os.listdir(pytorch_profiler.dirpath))
expected = f"fit-profiler-{trainer.local_rank}.txt"
assert expected in files
path = os.path.join(pytorch_profiler.dirpath, expected)
data = Path(path).read_text("utf-8")
assert len(data) > 0
assert Path(path).read_text()
def test_pytorch_profiler_trainer_test(tmpdir, pytorch_profiler):
"""Ensure that the profiler can be given to the trainer and test step are properly recorded. """
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_test_batches=2,
profiler=pytorch_profiler,
)
trainer.test(model)
assert sum(e.name == 'test_step' for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"test-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
def test_pytorch_profiler_trainer_predict(tmpdir, pytorch_profiler):
"""Ensure that the profiler can be given to the trainer and predict function are properly recorded. """
model = BoringModel()
model.predict_dataloader = model.train_dataloader
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_test_batches=2,
profiler=pytorch_profiler,
)
trainer.predict(model)
assert sum(e.name == 'predict_step' for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"predict-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
def test_pytorch_profiler_trainer_validate(tmpdir, pytorch_profiler):
"""Ensure that the profiler can be given to the trainer and validate function are properly recorded. """
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=2,
profiler=pytorch_profiler,
)
trainer.validate(model)
assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events)
path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
assert path.read_text("utf-8")
def test_pytorch_profiler_nested(tmpdir):
@ -341,34 +390,31 @@ def test_pytorch_profiler_nested(tmpdir):
with pytorch_profiler.profile("c"):
_ = a + b
pa = pytorch_profiler.profiled_actions
pytorch_profiler.describe()
# From PyTorch 1.8.0, less operation are being traced.
if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"):
expected_ = {
'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'add'],
'b': ['zeros', 'empty', 'zero_'],
'c': ['add'],
}
# From PyTorch 1.6.0, more operation are being traced.
elif LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
expected_ = {
'a': ['ones', 'empty', 'fill_', 'zeros', 'empty', 'zero_', 'fill_', 'add', 'empty'],
'b': ['zeros', 'empty', 'zero_', 'fill_'],
'c': ['add', 'empty'],
}
events_name = {e.name for e in pytorch_profiler.function_events}
if platform.system() == "Windows":
expected = {'a', 'add', 'b', 'c', 'profiler::_record_function_enter', 'profiler::_record_function_exit'}
else:
expected_ = {
'a': ['add'],
'b': [],
'c': ['add'],
expected = {
'signed char', 'add', 'profiler::_record_function_exit', 'bool', 'char', 'profiler::_record_function_enter'
}
for n in ('a', 'b', 'c'):
pa[n] = [e.name for e in pa[n]]
if LooseVersion(torch.__version__) >= LooseVersion("1.7.1"):
pa[n] = [e.replace("aten::", "") for e in pa[n]]
assert pa[n] == expected_[n]
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
expected = {'add', 'zeros', 'ones', 'zero_', 'b', 'fill_', 'c', 'a', 'empty'}
if LooseVersion(torch.__version__) >= LooseVersion("1.7.0"):
expected = {
'aten::zeros', 'aten::add', 'aten::zero_', 'c', 'b', 'a', 'aten::fill_', 'aten::empty', 'aten::ones'
}
if LooseVersion(torch.__version__) >= LooseVersion("1.8.0"):
expected = {
'aten::ones', 'a', 'aten::add', 'aten::empty', 'aten::zero_', 'b', 'c', 'aten::zeros', 'aten::fill_'
}
assert events_name == expected, (events_name, torch.__version__, platform.system())
@RunIf(min_gpus=1, special=True)
@ -387,6 +433,43 @@ def test_pytorch_profiler_nested_emit_nvtx(tmpdir):
trainer.fit(model)
@RunIf(min_torch="1.5.0")
def test_register_record_function(tmpdir):
use_cuda = torch.cuda.is_available()
pytorch_profiler = PyTorchProfiler(
export_to_chrome=False,
record_functions=["a"],
use_cuda=use_cuda,
dirpath=tmpdir,
filename="profiler",
)
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
model = TestModel()
input = torch.rand((1, 8))
if use_cuda:
model = model.cuda()
input = input.cuda()
with pytorch_profiler.profile("a"):
with RegisterRecordFunction(model):
model(input)
pytorch_profiler.describe()
event_names = [e.name for e in pytorch_profiler.function_events]
assert 'torch.nn.modules.container.Sequential: layer' in event_names
assert 'torch.nn.modules.linear.Linear: layer.0' in event_names
assert 'torch.nn.modules.activation.ReLU: layer.1' in event_names
assert 'torch.nn.modules.linear.Linear: layer.2' in event_names
@pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
def test_profiler_teardown(tmpdir, cls):
"""
@ -407,6 +490,9 @@ def test_profiler_teardown(tmpdir, cls):
assert profiler._output_file is None
@pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_8, reason="currently not supported for PyTorch 1.8")
def test_pytorch_profiler_deepcopy(pytorch_profiler):
pytorch_profiler.start("on_train_start")
torch.tensor(1)
pytorch_profiler.describe()
assert deepcopy(pytorch_profiler)

View File

@ -80,23 +80,3 @@ def test_get_model_gpu(tmpdir):
gpus=1,
)
trainer.fit(model)
@RunIf(min_gpus=1, skip_windows=True)
def test_get_model_ddp_gpu(tmpdir):
"""
Tests that `trainer.lightning_module` extracts the model correctly when using GPU + ddp accelerators
"""
model = TrainerGetModel()
limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
gpus=1,
)
trainer.fit(model)
return 1