From f68c0909fd70e4b7bfab08d76eea646a7d41acb9 Mon Sep 17 00:00:00 2001 From: Krishna Kalyan Date: Tue, 13 Sep 2022 18:11:45 +0200 Subject: [PATCH] Fix mypy errors attributed to `pytorch_lightning.profilers.pytorch` (#14405) * remove toml ref * fix conflicts * small fix * move assertion Co-authored-by: rohitgr7 --- pyproject.toml | 1 - src/pytorch_lightning/profilers/pytorch.py | 69 ++++++++++++++-------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dd48b8126a..777f86841a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ warn_no_return = "False" # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",' module = [ "pytorch_lightning.callbacks.progress.rich_progress", - "pytorch_lightning.profilers.pytorch", "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.data", diff --git a/src/pytorch_lightning/profilers/pytorch.py b/src/pytorch_lightning/profilers/pytorch.py index c7f34fdc79..475db682d9 100644 --- a/src/pytorch_lightning/profilers/pytorch.py +++ b/src/pytorch_lightning/profilers/pytorch.py @@ -17,7 +17,7 @@ import logging import os from functools import lru_cache, partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union +from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union import torch from lightning_utilities.core.rank_zero import WarningCache @@ -42,7 +42,7 @@ if _KINETO_AVAILABLE: log = logging.getLogger(__name__) warning_cache = WarningCache() -_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] +_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx] class RegisterRecordFunction: @@ -111,13 +111,7 @@ class ScheduleWrapper: self._schedule = schedule self.reset() - def setup(self, start_action_name: str) -> None: - self._start_action_name = start_action_name - - def pre_step(self, current_action: str) -> None: - self._current_action = current_action - - def reset(self): + def reset(self) -> None: # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise. self._num_training_step = 0 self._num_validation_step = 0 @@ -132,20 +126,30 @@ class ScheduleWrapper: self._prev_schedule_action: Optional[ProfilerAction] = None self._start_action_name: Optional[str] = None + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + @property - def is_training(self): + def is_training(self) -> bool: + assert self._current_action is not None return self._current_action.endswith("training_step") @property - def is_validating(self): + def is_validating(self) -> bool: + assert self._current_action is not None return self._current_action.endswith("validation_step") @property - def is_testing(self): + def is_testing(self) -> bool: + assert self._current_action is not None return self._current_action.endswith("test_step") @property - def is_predicting(self): + def is_predicting(self) -> bool: + assert self._current_action is not None return self._current_action.endswith("predict_step") @property @@ -164,6 +168,7 @@ class ScheduleWrapper: if self.is_training: self._num_training_step += 1 elif self.is_validating: + assert self._start_action_name is not None if self._start_action_name.endswith("on_fit_start"): if self._num_training_step > 0: self._num_validation_step += 1 @@ -238,7 +243,7 @@ class PyTorchProfiler(Profiler): record_module_names: bool = True, **profiler_kwargs: Any, ) -> None: - """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. + r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. different operators inside your model - both on the CPU and GPU @@ -276,7 +281,7 @@ class PyTorchProfiler(Profiler): record_module_names: Whether to add module names while recording autograd operation. - profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version + \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version Raises: MisconfigurationException: @@ -298,7 +303,7 @@ class PyTorchProfiler(Profiler): self.function_events: Optional["EventList"] = None self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[_PROFILER] = None + self._parent_profiler: Optional[ContextManager] = None self._recording_map: Dict[str, record_function] = {} self._start_action_name: Optional[str] = None self._schedule: Optional[ScheduleWrapper] = None @@ -317,7 +322,7 @@ class PyTorchProfiler(Profiler): schedule = profiler_kwargs.get("schedule", None) if schedule is not None: - if not isinstance(schedule, Callable): + if not callable(schedule): raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") action = schedule(0) if not isinstance(action, ProfilerAction): @@ -337,7 +342,9 @@ class PyTorchProfiler(Profiler): self._profiler_kwargs["with_stack"] = with_stack @property - def _total_steps(self) -> int: + def _total_steps(self) -> Union[int, float]: + assert self._schedule is not None + assert self._lightning_module is not None trainer = self._lightning_module.trainer if self._schedule.is_training: return trainer.num_training_batches @@ -358,13 +365,13 @@ class PyTorchProfiler(Profiler): @staticmethod @lru_cache(1) - def _default_schedule() -> Optional[callable]: + def _default_schedule() -> Optional[Callable]: if _KINETO_AVAILABLE: # Those schedule defaults allow the profiling overhead to be negligible over training time. return torch.profiler.schedule(wait=1, warmup=1, active=3) def _default_activities(self) -> List["ProfilerActivity"]: - activities = [] + activities: List["ProfilerActivity"] = [] if not _KINETO_AVAILABLE: return activities if self._profiler_kwargs.get("use_cpu", True): @@ -411,6 +418,7 @@ class PyTorchProfiler(Profiler): return if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS): + assert isinstance(self.profiler, torch.profiler.profile) if self._schedule is not None: self._schedule.pre_step(action_name) @@ -424,11 +432,11 @@ class PyTorchProfiler(Profiler): self._schedule = None self.profiler.schedule = torch.profiler.profiler._default_schedule_fn - def on_trace_ready(profiler): + def on_trace_ready(profiler: _PROFILER) -> None: if self.dirpath is not None: if self._export_to_chrome: handler = tensorboard_trace_handler( - self.dirpath, self._prepare_filename(action_name=action_name, extension="") + str(self.dirpath), self._prepare_filename(action_name=action_name, extension="") ) handler(profiler) @@ -436,6 +444,7 @@ class PyTorchProfiler(Profiler): path = os.path.join( self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack") ) + assert isinstance(profiler, torch.autograd.profiler.profile) profiler.export_stacks(path, metric=self._metric) else: rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") @@ -469,8 +478,12 @@ class PyTorchProfiler(Profiler): return self._stats_to_str(recorded_stats) def _create_profilers(self) -> None: + if self.profiler is not None: + return + if self._emit_nvtx: - self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + if self._parent_profiler is None: + self._parent_profiler = torch.cuda.profiler.profile() self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx) else: self._parent_profiler = None @@ -486,7 +499,13 @@ class PyTorchProfiler(Profiler): def _cache_functions_events(self) -> None: if self._emit_nvtx: return - self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events + + if _KINETO_AVAILABLE: + assert isinstance(self.profiler, torch.profiler.profile) + self.function_events = self.profiler.events() + else: + assert isinstance(self.profiler, torch.autograd.profiler.profile) + self.function_events = self.profiler.function_events def _delete_profilers(self) -> None: if self.profiler is not None: @@ -505,7 +524,7 @@ class PyTorchProfiler(Profiler): self._register.__exit__(None, None, None) self._register = None - def teardown(self, stage: str) -> None: + def teardown(self, stage: Optional[str]) -> None: self._delete_profilers() for k in list(self._recording_map):