Resolve schedule step bug for PyTorch Profiler (#6674)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
217c12a4e7
commit
0ea8f39841
|
@ -107,6 +107,15 @@ class ScheduleWrapper:
|
|||
if not _KINETO_AVAILABLE:
|
||||
raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.")
|
||||
self._schedule = schedule
|
||||
self.reset()
|
||||
|
||||
def setup(self, start_action_name: str) -> None:
|
||||
self._start_action_name = start_action_name
|
||||
|
||||
def pre_step(self, current_action: str) -> None:
|
||||
self._current_action = current_action
|
||||
|
||||
def reset(self):
|
||||
self._num_training_step_and_backward = 0
|
||||
self._num_validation_step = 0
|
||||
self._num_test_step = 0
|
||||
|
@ -119,12 +128,6 @@ class ScheduleWrapper:
|
|||
self._current_action: Optional[str] = None
|
||||
self._start_action_name: Optional[str] = None
|
||||
|
||||
def setup(self, start_action_name: str) -> None:
|
||||
self._start_action_name = start_action_name
|
||||
|
||||
def pre_step(self, current_action: str) -> None:
|
||||
self._current_action = current_action
|
||||
|
||||
@property
|
||||
def num_step(self) -> int:
|
||||
if self._current_action == "training_step_and_backward":
|
||||
|
@ -142,8 +145,9 @@ class ScheduleWrapper:
|
|||
if self._current_action == "training_step_and_backward":
|
||||
self._num_training_step_and_backward += 1
|
||||
elif self._current_action == "validation_step":
|
||||
if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0:
|
||||
self._num_validation_step += 1
|
||||
if self._start_action_name == "on_fit_start":
|
||||
if self._num_training_step_and_backward > 0:
|
||||
self._num_validation_step += 1
|
||||
else:
|
||||
self._num_validation_step += 1
|
||||
elif self._current_action == "test_step":
|
||||
|
@ -210,7 +214,7 @@ class PyTorchProfiler(BaseProfiler):
|
|||
"count",
|
||||
}
|
||||
START_RECORD_FUNCTIONS = {
|
||||
'on_train_start',
|
||||
'on_fit_start',
|
||||
'on_validation_start',
|
||||
'on_test_start',
|
||||
'on_predict_start',
|
||||
|
@ -289,8 +293,9 @@ class PyTorchProfiler(BaseProfiler):
|
|||
self._export_to_chrome = export_to_chrome
|
||||
self._row_limit = row_limit
|
||||
self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
|
||||
self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS
|
||||
self._record_functions = record_functions | self.RECORD_FUNCTIONS
|
||||
self._user_record_functions = record_functions
|
||||
self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
|
||||
self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
|
||||
self._record_module_names = record_module_names
|
||||
self._profiler_kwargs = profiler_kwargs
|
||||
|
||||
|
@ -304,14 +309,14 @@ class PyTorchProfiler(BaseProfiler):
|
|||
self._schedule: Optional[ScheduleWrapper] = None
|
||||
|
||||
if _KINETO_AVAILABLE:
|
||||
self.__init_kineto__(profiler_kwargs)
|
||||
self._init_kineto(profiler_kwargs)
|
||||
|
||||
if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
|
||||
raise MisconfigurationException(
|
||||
f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
|
||||
)
|
||||
|
||||
def __init_kineto__(self, profiler_kwargs: Any):
|
||||
def _init_kineto(self, profiler_kwargs: Any) -> None:
|
||||
has_schedule = "schedule" in profiler_kwargs
|
||||
self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
|
||||
|
||||
|
@ -362,7 +367,7 @@ class PyTorchProfiler(BaseProfiler):
|
|||
def _default_schedule() -> Optional[callable]:
|
||||
if _KINETO_AVAILABLE:
|
||||
# Those schedule defaults allow the profiling overhead to be negligible over training time.
|
||||
return torch.profiler.schedule(wait=1, warmup=1, active=2)
|
||||
return torch.profiler.schedule(wait=1, warmup=1, active=3)
|
||||
|
||||
def _default_activities(self) -> List['ProfilerActivity']:
|
||||
activities = []
|
||||
|
@ -374,10 +379,6 @@ class PyTorchProfiler(BaseProfiler):
|
|||
activities.append(ProfilerActivity.CUDA)
|
||||
return activities
|
||||
|
||||
@property
|
||||
def step_action_names(self) -> Set[str]:
|
||||
return self.STEP_FUNCTIONS | self._record_functions
|
||||
|
||||
def start(self, action_name: str) -> None:
|
||||
if self.profiler is None and action_name in self._record_functions_start:
|
||||
|
||||
|
@ -411,9 +412,6 @@ class PyTorchProfiler(BaseProfiler):
|
|||
recording.__enter__()
|
||||
self._recording_map[action_name] = recording
|
||||
|
||||
if self._schedule is not None:
|
||||
self._schedule.pre_step(action_name)
|
||||
|
||||
def stop(self, action_name: str) -> None:
|
||||
if action_name in self._recording_map:
|
||||
self._recording_map[action_name].__exit__(None, None, None)
|
||||
|
@ -422,16 +420,14 @@ class PyTorchProfiler(BaseProfiler):
|
|||
if not _KINETO_AVAILABLE or self._emit_nvtx:
|
||||
return
|
||||
|
||||
if action_name in self.step_action_names:
|
||||
if self.profiler is not None and action_name in self.STEP_FUNCTIONS:
|
||||
if self._schedule is not None:
|
||||
self._schedule._current_action = action_name
|
||||
self._schedule.pre_step(action_name)
|
||||
|
||||
def on_trace_ready(profiler):
|
||||
filename = f"{action_name}_{self.local_rank}"
|
||||
|
||||
if self.dirpath is not None:
|
||||
if self._export_to_chrome:
|
||||
handler = tensorboard_trace_handler(self.dirpath, filename)
|
||||
handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
|
||||
handler(profiler)
|
||||
|
||||
if self._export_to_flame_graph:
|
||||
|
@ -442,6 +438,9 @@ class PyTorchProfiler(BaseProfiler):
|
|||
|
||||
if not self._has_on_trace_ready:
|
||||
self.profiler.on_trace_ready = on_trace_ready
|
||||
|
||||
if self._schedule is not None:
|
||||
self.profiler.step_num = self._schedule.num_step
|
||||
self.profiler.step()
|
||||
|
||||
def summary(self) -> str:
|
||||
|
@ -492,6 +491,9 @@ class PyTorchProfiler(BaseProfiler):
|
|||
self._cache_functions_events()
|
||||
self.profiler = None
|
||||
|
||||
if self._schedule is not None:
|
||||
self._schedule.reset()
|
||||
|
||||
if self._parent_profiler is not None:
|
||||
self._parent_profiler.__exit__(None, None, None)
|
||||
self._parent_profiler = None
|
||||
|
|
|
@ -57,5 +57,5 @@ class ProfilerConnector:
|
|||
def setup(self) -> None:
|
||||
trainer = self.trainer
|
||||
local_rank = trainer.local_rank if trainer.world_size > 1 else None
|
||||
trainer.profiler.lightning_module = proxy(trainer.lightning_module)
|
||||
trainer.profiler._lightning_module = proxy(trainer.lightning_module)
|
||||
trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
|
||||
|
|
|
@ -11,6 +11,7 @@ VERSIONS_LUT: Dict[str, Dict[str, Any]] = {
|
|||
"1.7.0": dict(torchvision="0.8.1", torchtext="0.8"),
|
||||
"1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"),
|
||||
"1.8.0": dict(torchvision="0.9.0", torchtext="0.9"),
|
||||
"1.8.1": dict(torchvision="0.9.0", torchtext="0.9"),
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue