Fix mypy errors attributed to `pytorch_lightning.profilers.pytorch` (#14405)
* remove toml ref * fix conflicts * small fix * move assertion Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
This commit is contained in:
parent
c81a71c908
commit
f68c0909fd
|
@ -52,7 +52,6 @@ warn_no_return = "False"
|
|||
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
|
||||
module = [
|
||||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.profilers.pytorch",
|
||||
"pytorch_lightning.trainer.trainer",
|
||||
"pytorch_lightning.tuner.batch_size_scaling",
|
||||
"pytorch_lightning.utilities.data",
|
||||
|
|
|
@ -17,7 +17,7 @@ import logging
|
|||
import os
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, ContextManager, Dict, List, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from lightning_utilities.core.rank_zero import WarningCache
|
||||
|
@ -42,7 +42,7 @@ if _KINETO_AVAILABLE:
|
|||
log = logging.getLogger(__name__)
|
||||
warning_cache = WarningCache()
|
||||
|
||||
_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx]
|
||||
_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
|
||||
|
||||
|
||||
class RegisterRecordFunction:
|
||||
|
@ -111,13 +111,7 @@ class ScheduleWrapper:
|
|||
self._schedule = schedule
|
||||
self.reset()
|
||||
|
||||
def setup(self, start_action_name: str) -> None:
|
||||
self._start_action_name = start_action_name
|
||||
|
||||
def pre_step(self, current_action: str) -> None:
|
||||
self._current_action = current_action
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
# handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise.
|
||||
self._num_training_step = 0
|
||||
self._num_validation_step = 0
|
||||
|
@ -132,20 +126,30 @@ class ScheduleWrapper:
|
|||
self._prev_schedule_action: Optional[ProfilerAction] = None
|
||||
self._start_action_name: Optional[str] = None
|
||||
|
||||
def setup(self, start_action_name: str) -> None:
|
||||
self._start_action_name = start_action_name
|
||||
|
||||
def pre_step(self, current_action: str) -> None:
|
||||
self._current_action = current_action
|
||||
|
||||
@property
|
||||
def is_training(self):
|
||||
def is_training(self) -> bool:
|
||||
assert self._current_action is not None
|
||||
return self._current_action.endswith("training_step")
|
||||
|
||||
@property
|
||||
def is_validating(self):
|
||||
def is_validating(self) -> bool:
|
||||
assert self._current_action is not None
|
||||
return self._current_action.endswith("validation_step")
|
||||
|
||||
@property
|
||||
def is_testing(self):
|
||||
def is_testing(self) -> bool:
|
||||
assert self._current_action is not None
|
||||
return self._current_action.endswith("test_step")
|
||||
|
||||
@property
|
||||
def is_predicting(self):
|
||||
def is_predicting(self) -> bool:
|
||||
assert self._current_action is not None
|
||||
return self._current_action.endswith("predict_step")
|
||||
|
||||
@property
|
||||
|
@ -164,6 +168,7 @@ class ScheduleWrapper:
|
|||
if self.is_training:
|
||||
self._num_training_step += 1
|
||||
elif self.is_validating:
|
||||
assert self._start_action_name is not None
|
||||
if self._start_action_name.endswith("on_fit_start"):
|
||||
if self._num_training_step > 0:
|
||||
self._num_validation_step += 1
|
||||
|
@ -238,7 +243,7 @@ class PyTorchProfiler(Profiler):
|
|||
record_module_names: bool = True,
|
||||
**profiler_kwargs: Any,
|
||||
) -> None:
|
||||
"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
|
||||
r"""This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of.
|
||||
|
||||
different operators inside your model - both on the CPU and GPU
|
||||
|
||||
|
@ -276,7 +281,7 @@ class PyTorchProfiler(Profiler):
|
|||
|
||||
record_module_names: Whether to add module names while recording autograd operation.
|
||||
|
||||
profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
|
||||
\**profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version
|
||||
|
||||
Raises:
|
||||
MisconfigurationException:
|
||||
|
@ -298,7 +303,7 @@ class PyTorchProfiler(Profiler):
|
|||
self.function_events: Optional["EventList"] = None
|
||||
self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector
|
||||
self._register: Optional[RegisterRecordFunction] = None
|
||||
self._parent_profiler: Optional[_PROFILER] = None
|
||||
self._parent_profiler: Optional[ContextManager] = None
|
||||
self._recording_map: Dict[str, record_function] = {}
|
||||
self._start_action_name: Optional[str] = None
|
||||
self._schedule: Optional[ScheduleWrapper] = None
|
||||
|
@ -317,7 +322,7 @@ class PyTorchProfiler(Profiler):
|
|||
|
||||
schedule = profiler_kwargs.get("schedule", None)
|
||||
if schedule is not None:
|
||||
if not isinstance(schedule, Callable):
|
||||
if not callable(schedule):
|
||||
raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
|
||||
action = schedule(0)
|
||||
if not isinstance(action, ProfilerAction):
|
||||
|
@ -337,7 +342,9 @@ class PyTorchProfiler(Profiler):
|
|||
self._profiler_kwargs["with_stack"] = with_stack
|
||||
|
||||
@property
|
||||
def _total_steps(self) -> int:
|
||||
def _total_steps(self) -> Union[int, float]:
|
||||
assert self._schedule is not None
|
||||
assert self._lightning_module is not None
|
||||
trainer = self._lightning_module.trainer
|
||||
if self._schedule.is_training:
|
||||
return trainer.num_training_batches
|
||||
|
@ -358,13 +365,13 @@ class PyTorchProfiler(Profiler):
|
|||
|
||||
@staticmethod
|
||||
@lru_cache(1)
|
||||
def _default_schedule() -> Optional[callable]:
|
||||
def _default_schedule() -> Optional[Callable]:
|
||||
if _KINETO_AVAILABLE:
|
||||
# Those schedule defaults allow the profiling overhead to be negligible over training time.
|
||||
return torch.profiler.schedule(wait=1, warmup=1, active=3)
|
||||
|
||||
def _default_activities(self) -> List["ProfilerActivity"]:
|
||||
activities = []
|
||||
activities: List["ProfilerActivity"] = []
|
||||
if not _KINETO_AVAILABLE:
|
||||
return activities
|
||||
if self._profiler_kwargs.get("use_cpu", True):
|
||||
|
@ -411,6 +418,7 @@ class PyTorchProfiler(Profiler):
|
|||
return
|
||||
|
||||
if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS):
|
||||
assert isinstance(self.profiler, torch.profiler.profile)
|
||||
if self._schedule is not None:
|
||||
self._schedule.pre_step(action_name)
|
||||
|
||||
|
@ -424,11 +432,11 @@ class PyTorchProfiler(Profiler):
|
|||
self._schedule = None
|
||||
self.profiler.schedule = torch.profiler.profiler._default_schedule_fn
|
||||
|
||||
def on_trace_ready(profiler):
|
||||
def on_trace_ready(profiler: _PROFILER) -> None:
|
||||
if self.dirpath is not None:
|
||||
if self._export_to_chrome:
|
||||
handler = tensorboard_trace_handler(
|
||||
self.dirpath, self._prepare_filename(action_name=action_name, extension="")
|
||||
str(self.dirpath), self._prepare_filename(action_name=action_name, extension="")
|
||||
)
|
||||
handler(profiler)
|
||||
|
||||
|
@ -436,6 +444,7 @@ class PyTorchProfiler(Profiler):
|
|||
path = os.path.join(
|
||||
self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack")
|
||||
)
|
||||
assert isinstance(profiler, torch.autograd.profiler.profile)
|
||||
profiler.export_stacks(path, metric=self._metric)
|
||||
else:
|
||||
rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None")
|
||||
|
@ -469,8 +478,12 @@ class PyTorchProfiler(Profiler):
|
|||
return self._stats_to_str(recorded_stats)
|
||||
|
||||
def _create_profilers(self) -> None:
|
||||
if self.profiler is not None:
|
||||
return
|
||||
|
||||
if self._emit_nvtx:
|
||||
self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile)
|
||||
if self._parent_profiler is None:
|
||||
self._parent_profiler = torch.cuda.profiler.profile()
|
||||
self.profiler = self._create_profiler(torch.autograd.profiler.emit_nvtx)
|
||||
else:
|
||||
self._parent_profiler = None
|
||||
|
@ -486,7 +499,13 @@ class PyTorchProfiler(Profiler):
|
|||
def _cache_functions_events(self) -> None:
|
||||
if self._emit_nvtx:
|
||||
return
|
||||
self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
|
||||
|
||||
if _KINETO_AVAILABLE:
|
||||
assert isinstance(self.profiler, torch.profiler.profile)
|
||||
self.function_events = self.profiler.events()
|
||||
else:
|
||||
assert isinstance(self.profiler, torch.autograd.profiler.profile)
|
||||
self.function_events = self.profiler.function_events
|
||||
|
||||
def _delete_profilers(self) -> None:
|
||||
if self.profiler is not None:
|
||||
|
@ -505,7 +524,7 @@ class PyTorchProfiler(Profiler):
|
|||
self._register.__exit__(None, None, None)
|
||||
self._register = None
|
||||
|
||||
def teardown(self, stage: str) -> None:
|
||||
def teardown(self, stage: Optional[str]) -> None:
|
||||
self._delete_profilers()
|
||||
|
||||
for k in list(self._recording_map):
|
||||
|
|
Loading…
Reference in New Issue