Fix mypy errors attributed to `pytorch_lightning.profilers.pytorch` (#14405)

* remove toml ref
* fix conflicts
* small fix
* move assertion

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Authored by Krishna Kalyan on 2022-09-13 18:11:45 +02:00, committed via GitHub
parent c81a71c908
commit f68c0909fd
2 changed files with 44 additions and 26 deletions

pyproject.toml

@@ -52,7 +52,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.data",

src/pytorch_lightning/profilers/pytorch.py

@@ -17,7 +17,7 @@ import logging
 import os
 from functools import lru_cache, partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union

 import torch
 from lightning_utilities.core.rank_zero import WarningCache
@@ -42,7 +42,7 @@ if _KINETO_AVAILABLE:
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()

-_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
+_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]


 class RegisterRecordFunction:
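The alias previously listed `torch.cuda.profiler.profile`, but that object is only ever held as the parent context manager and is never assigned to `self.profiler`; the Kineto-backed `torch.profiler.profile` is what actually gets stored. A minimal sketch of the intended split (the variable names here are illustrative, not from the module):

```python
from typing import ContextManager, Optional, Union

import torch

# Mirrors the corrected alias: only objects ever bound to `self.profiler`.
_PROFILER = Union[
    torch.profiler.profile,
    torch.autograd.profiler.profile,
    torch.autograd.profiler.emit_nvtx,
]

profiler: Optional[_PROFILER] = None
# The CUDA profiler is kept out of the union; it is just a context manager.
parent_profiler: Optional[ContextManager] = torch.cuda.profiler.profile()
```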
@@ -111,13 +111,7 @@ class ScheduleWrapper:
         self._schedule = schedule
         self.reset()

-    def setup(self, start_action_name: str) -> None:
-        self._start_action_name = start_action_name
-
-    def pre_step(self, current_action: str) -> None:
-        self._current_action = current_action
-
-    def reset(self):
+    def reset(self) -> None:
         # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
         self._num_training_step = 0
         self._num_validation_step = 0
@@ -132,20 +126,30 @@ class ScheduleWrapper:
         self._prev_schedule_action: Optional[ProfilerAction] = None
         self._start_action_name: Optional[str] = None

+    def setup(self, start_action_name: str) -> None:
+        self._start_action_name = start_action_name
+
+    def pre_step(self, current_action: str) -> None:
+        self._current_action = current_action
+
     @property
-    def is_training(self):
+    def is_training(self) -> bool:
+        assert self._current_action is not None
         return self._current_action.endswith("training_step")

     @property
-    def is_validating(self):
+    def is_validating(self) -> bool:
+        assert self._current_action is not None
         return self._current_action.endswith("validation_step")

     @property
-    def is_testing(self):
+    def is_testing(self) -> bool:
+        assert self._current_action is not None
         return self._current_action.endswith("test_step")

     @property
-    def is_predicting(self):
+    def is_predicting(self) -> bool:
+        assert self._current_action is not None
         return self._current_action.endswith("predict_step")

     @property
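Because `_current_action` is declared `Optional[str]`, mypy cannot see that it is always set before these properties are read; the asserts narrow `Optional[str]` to `str` within each body. A standalone sketch of the pattern (the class name is illustrative):

```python
from typing import Optional


class ActionTracker:
    def __init__(self) -> None:
        self._current_action: Optional[str] = None

    @property
    def is_training(self) -> bool:
        # Without this assert, mypy reports:
        #   Item "None" of "Optional[str]" has no attribute "endswith"
        assert self._current_action is not None
        return self._current_action.endswith("training_step")
```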
@@ -164,6 +168,7 @@ class ScheduleWrapper:
         if self.is_training:
             self._num_training_step += 1
         elif self.is_validating:
+            assert self._start_action_name is not None
             if self._start_action_name.endswith("on_fit_start"):
                 if self._num_training_step > 0:
                     self._num_validation_step += 1
@@ -238,7 +243,7 @@ class PyTorchProfiler(Profiler):
         record_module_names: bool = True,
         **profiler_kwargs: Any,
     ) -> None:
-        """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
+        r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.

         different operators inside your model - both on the CPU and GPU
@@ -276,7 +281,7 @@ class PyTorchProfiler(Profiler):
             record_module_names: Whether to add module names while recording autograd operation.

-            profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
+            \**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version

         Raises:
             MisconfigurationException:
@@ -298,7 +303,7 @@ class PyTorchProfiler(Profiler):
         self.function_events: Optional["EventList"] = None
         self._lightning_module: Optional["LightningModule"] = None  # set by ProfilerConnector
         self._register: Optional[RegisterRecordFunction] = None
-        self._parent_profiler: Optional[_PROFILER] = None
+        self._parent_profiler: Optional[ContextManager] = None
         self._recording_map: Dict[str, record_function] = {}
         self._start_action_name: Optional[str] = None
         self._schedule: Optional[ScheduleWrapper] = None
@@ -317,7 +322,7 @@ class PyTorchProfiler(Profiler):
         schedule = profiler_kwargs.get("schedule", None)
         if schedule is not None:
-            if not isinstance(schedule, Callable):
+            if not callable(schedule):
                 raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
             action = schedule(0)
             if not isinstance(action, ProfilerAction):
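`isinstance(schedule, Callable)` passes a typing construct to a runtime check, which strict mypy settings reject; the builtin `callable()` is the idiomatic test, and mypy also narrows on it. A small sketch under that assumption (the helper is hypothetical):

```python
from typing import Any


def validate_schedule(schedule: Any) -> None:
    # callable() doubles as a type guard, so the call below type-checks.
    if not callable(schedule):
        raise TypeError(f"Schedule should be a callable. Found: {schedule}")
    schedule(0)
```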
@@ -337,7 +342,9 @@ class PyTorchProfiler(Profiler):
             self._profiler_kwargs["with_stack"] = with_stack

     @property
-    def _total_steps(self) -> int:
+    def _total_steps(self) -> Union[int, float]:
+        assert self._schedule is not None
+        assert self._lightning_module is not None
         trainer = self._lightning_module.trainer
         if self._schedule.is_training:
             return trainer.num_training_batches
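The widened return type reflects that `trainer.num_training_batches` can be `float("inf")` when a dataloader has no length, so a step total is not guaranteed to be an `int`. A hypothetical illustration (the helper name is made up):

```python
import math
from typing import Union


def expected_total_steps(num_batches: Union[int, float], epochs: int) -> Union[int, float]:
    # num_batches may be float("inf") for an unsized iterable dataloader;
    # multiplying keeps it inf rather than raising, hence int | float.
    return num_batches * epochs


assert expected_total_steps(math.inf, 3) == math.inf
```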
@@ -358,13 +365,13 @@ class PyTorchProfiler(Profiler):

     @staticmethod
     @lru_cache(1)
-    def _default_schedule() -> Optional[callable]:
+    def _default_schedule() -> Optional[Callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
             return torch.profiler.schedule(wait=1, warmup=1, active=3)

     def _default_activities(self) -> List["ProfilerActivity"]:
-        activities = []
+        activities: List["ProfilerActivity"] = []
         if not _KINETO_AVAILABLE:
             return activities
         if self._profiler_kwargs.get("use_cpu", True):
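`torch.profiler.schedule` returns a callable mapping a step index to a `ProfilerAction`, which is why the annotation is now `Optional[Callable]`. A quick check of the defaults above, assuming a Kineto-enabled build:

```python
from torch.profiler import ProfilerAction, schedule

sched = schedule(wait=1, warmup=1, active=3)

# Step 0 idles, step 1 warms up, steps 2-4 record;
# the last active step also flushes the trace.
assert sched(0) == ProfilerAction.NONE
assert sched(1) == ProfilerAction.WARMUP
assert sched(2) == ProfilerAction.RECORD
assert sched(4) == ProfilerAction.RECORD_AND_SAVE
```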
@@ -411,6 +418,7 @@ class PyTorchProfiler(Profiler):
             return
         if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
+            assert isinstance(self.profiler, torch.profiler.profile)
             if self._schedule is not None:
                 self._schedule.pre_step(action_name)
@@ -424,11 +432,11 @@ class PyTorchProfiler(Profiler):
                 self._schedule = None
                 self.profiler.schedule = torch.profiler.profiler._default_schedule_fn

-            def on_trace_ready(profiler):
+            def on_trace_ready(profiler: _PROFILER) -> None:
                 if self.dirpath is not None:
                     if self._export_to_chrome:
                         handler = tensorboard_trace_handler(
-                            self.dirpath, self._prepare_filename(action_name=action_name, extension="")
+                            str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
                         )
                         handler(profiler)
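`tensorboard_trace_handler` types its `dir_name` parameter as `str`, while the profiler's `dirpath` may be a `pathlib.Path`, so the explicit `str()` conversion is what satisfies mypy here. A minimal sketch (the path and worker name are illustrative):

```python
from pathlib import Path

from torch.profiler import tensorboard_trace_handler

dirpath = Path("./profiler_traces")
# dir_name is annotated as str, so a Path must be converted explicitly.
handler = tensorboard_trace_handler(str(dirpath), worker_name="example")
```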
@@ -436,6 +444,7 @@ class PyTorchProfiler(Profiler):
                     path = os.path.join(
                         self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
                     )
+                    assert isinstance(profiler, torch.autograd.profiler.profile)
                     profiler.export_stacks(path, metric=self._metric)
                 else:
                     rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
@@ -469,8 +478,12 @@ class PyTorchProfiler(Profiler):
         return self._stats_to_str(recorded_stats)

     def _create_profilers(self) -> None:
+        if self.profiler is not None:
+            return
+
         if self._emit_nvtx:
-            self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
+            if self._parent_profiler is None:
+                self._parent_profiler = torch.cuda.profiler.profile()
             self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
         else:
             self._parent_profiler = None
@@ -486,7 +499,13 @@ class PyTorchProfiler(Profiler):
     def _cache_functions_events(self) -> None:
         if self._emit_nvtx:
             return
-        self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
+
+        if _KINETO_AVAILABLE:
+            assert isinstance(self.profiler, torch.profiler.profile)
+            self.function_events = self.profiler.events()
+        else:
+            assert isinstance(self.profiler, torch.autograd.profiler.profile)
+            self.function_events = self.profiler.function_events

     def _delete_profilers(self) -> None:
         if self.profiler is not None:
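`self.profiler` is typed as the `_PROFILER` union, and each branch uses an API that exists on only one member: `events()` on the Kineto profiler, `function_events` on the legacy autograd profiler. The `isinstance` asserts narrow the union so mypy accepts the attribute access. A self-contained sketch of the same narrowing (the function is illustrative, not part of the module):

```python
from typing import Union

import torch

_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile]


def collect_events(profiler: _PROFILER):
    if isinstance(profiler, torch.profiler.profile):
        return profiler.events()  # Kineto profiler exposes events()
    # Narrowed to the legacy autograd profiler here.
    return profiler.function_events
```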
@@ -505,7 +524,7 @@ class PyTorchProfiler(Profiler):
             self._register.__exit__(None, None, None)
             self._register = None

-    def teardown(self, stage: str) -> None:
+    def teardown(self, stage: Optional[str]) -> None:
         self._delete_profilers()

         for k in list(self._recording_map):