Resolve schedule step bug for PyTorch Profiler (#6674)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
thomas chaton authored on 2021-03-25 16:03:06 +00:00, committed by GitHub
parent 217c12a4e7
commit 0ea8f39841
3 changed files with 30 additions and 27 deletions


@@ -107,6 +107,15 @@ class ScheduleWrapper:
         if not _KINETO_AVAILABLE:
             raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.")
         self._schedule = schedule
+        self.reset()
+
+    def setup(self, start_action_name: str) -> None:
+        self._start_action_name = start_action_name
+
+    def pre_step(self, current_action: str) -> None:
+        self._current_action = current_action
+
+    def reset(self):
         self._num_training_step_and_backward = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -119,12 +128,6 @@
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None
 
-    def setup(self, start_action_name: str) -> None:
-        self._start_action_name = start_action_name
-
-    def pre_step(self, current_action: str) -> None:
-        self._current_action = current_action
-
     @property
     def num_step(self) -> int:
         if self._current_action == "training_step_and_backward":
@@ -142,8 +145,9 @@ class ScheduleWrapper:
         if self._current_action == "training_step_and_backward":
             self._num_training_step_and_backward += 1
         elif self._current_action == "validation_step":
-            if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0:
-                self._num_validation_step += 1
+            if self._start_action_name == "on_fit_start":
+                if self._num_training_step_and_backward > 0:
+                    self._num_validation_step += 1
             else:
                 self._num_validation_step += 1
         elif self._current_action == "test_step":
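
Context for the change above: when the profiler is started from on_fit_start, the trainer runs its validation sanity check before the first training step, and counting those validation steps would advance the kineto schedule prematurely. A minimal sketch of the decision the new branch encodes (the helper name below is illustrative only, not part of the library):

    def should_count_validation_step(start_action_name: str, training_steps_so_far: int) -> bool:
        # Profiling began at ``on_fit_start``: ignore the sanity-check validation
        # steps that run before the first training step.
        if start_action_name == "on_fit_start":
            return training_steps_so_far > 0
        # Profiling began from a standalone validation/test run: count every step.
        return True

    assert should_count_validation_step("on_fit_start", 0) is False        # sanity checking
    assert should_count_validation_step("on_fit_start", 5) is True         # validation during fit
    assert should_count_validation_step("on_validation_start", 0) is True  # trainer.validate(...)
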
@@ -210,7 +214,7 @@ class PyTorchProfiler(BaseProfiler):
         "count",
     }
     START_RECORD_FUNCTIONS = {
-        'on_train_start',
+        'on_fit_start',
         'on_validation_start',
         'on_test_start',
         'on_predict_start',
@@ -289,8 +293,9 @@ class PyTorchProfiler(BaseProfiler):
         self._export_to_chrome = export_to_chrome
         self._row_limit = row_limit
         self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
-        self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS
-        self._record_functions = record_functions | self.RECORD_FUNCTIONS
+        self._user_record_functions = record_functions
+        self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
+        self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
         self._record_module_names = record_module_names
         self._profiler_kwargs = profiler_kwargs
@@ -304,14 +309,14 @@
         self._schedule: Optional[ScheduleWrapper] = None
 
         if _KINETO_AVAILABLE:
-            self.__init_kineto__(profiler_kwargs)
+            self._init_kineto(profiler_kwargs)
 
         if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
             raise MisconfigurationException(
                 f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
             )
 
-    def __init_kineto__(self, profiler_kwargs: Any):
+    def _init_kineto(self, profiler_kwargs: Any) -> None:
         has_schedule = "schedule" in profiler_kwargs
         self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
@@ -362,7 +367,7 @@
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
-            return torch.profiler.schedule(wait=1, warmup=1, active=2)
+            return torch.profiler.schedule(wait=1, warmup=1, active=3)
 
     def _default_activities(self) -> List['ProfilerActivity']:
         activities = []
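
The default schedule above is the stock torch.profiler.schedule: with wait=1, warmup=1, active=3, the first step is skipped, the second warms the profiler up, and the next three are recorded before the trace handler fires and the cycle repeats. A minimal sketch of that behaviour, assuming torch>=1.8.1 where the torch.profiler API is available:

    from torch.profiler import ProfilerAction, schedule

    # Same arguments as the default above: 1 wait step, 1 warmup step, 3 recorded steps.
    sched = schedule(wait=1, warmup=1, active=3)

    actions = [sched(step) for step in range(6)]
    # [NONE, WARMUP, RECORD, RECORD, RECORD_AND_SAVE, NONE] -- on_trace_ready fires
    # on the RECORD_AND_SAVE step, then a new wait/warmup/active cycle begins.
    assert actions[0] is ProfilerAction.NONE
    assert actions[4] is ProfilerAction.RECORD_AND_SAVE

The later hunk that assigns self.profiler.step_num from the wrapper's num_step keeps this schedule aligned with the ScheduleWrapper's per-action counters.
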
@@ -374,10 +379,6 @@
             activities.append(ProfilerActivity.CUDA)
         return activities
 
-    @property
-    def step_action_names(self) -> Set[str]:
-        return self.STEP_FUNCTIONS | self._record_functions
-
     def start(self, action_name: str) -> None:
         if self.profiler is None and action_name in self._record_functions_start:
@@ -411,9 +412,6 @@
             recording.__enter__()
             self._recording_map[action_name] = recording
 
-        if self._schedule is not None:
-            self._schedule.pre_step(action_name)
-
     def stop(self, action_name: str) -> None:
         if action_name in self._recording_map:
             self._recording_map[action_name].__exit__(None, None, None)
@@ -422,16 +420,14 @@
         if not _KINETO_AVAILABLE or self._emit_nvtx:
             return
 
-        if action_name in self.step_action_names:
+        if self.profiler is not None and action_name in self.STEP_FUNCTIONS:
             if self._schedule is not None:
-                self._schedule._current_action = action_name
+                self._schedule.pre_step(action_name)
 
             def on_trace_ready(profiler):
-                filename = f"{action_name}_{self.local_rank}"
-
                 if self.dirpath is not None:
                     if self._export_to_chrome:
-                        handler = tensorboard_trace_handler(self.dirpath, filename)
+                        handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
                         handler(profiler)
 
                     if self._export_to_flame_graph:
@@ -442,6 +438,9 @@
             if not self._has_on_trace_ready:
                 self.profiler.on_trace_ready = on_trace_ready
 
+            if self._schedule is not None:
+                self.profiler.step_num = self._schedule.num_step
+
             self.profiler.step()
 
     def summary(self) -> str:
@@ -492,6 +491,9 @@
             self._cache_functions_events()
         self.profiler = None
 
+        if self._schedule is not None:
+            self._schedule.reset()
+
         if self._parent_profiler is not None:
             self._parent_profiler.__exit__(None, None, None)
             self._parent_profiler = None


@@ -57,5 +57,5 @@ class ProfilerConnector:
     def setup(self) -> None:
         trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        trainer.profiler.lightning_module = proxy(trainer.lightning_module)
+        trainer.profiler._lightning_module = proxy(trainer.lightning_module)
         trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
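
The connector change only moves the proxied module to the private _lightning_module attribute; the proxy call itself is unchanged. As a reminder of what the weak proxy buys, a small stdlib-only sketch (the Module and Holder classes are stand-ins invented for illustration, assuming proxy here is weakref.proxy):

    import weakref

    class Module:   # stand-in for the LightningModule
        name = "my_module"

    class Holder:   # stand-in for the profiler
        _lightning_module = None

    module, holder = Module(), Holder()
    holder._lightning_module = weakref.proxy(module)
    print(holder._lightning_module.name)  # attribute access is forwarded -> "my_module"

    del module  # the proxy keeps no strong reference; further access raises ReferenceError
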


@@ -11,6 +11,7 @@ VERSIONS_LUT: Dict[str, Dict[str, Any]] = {
     "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"),
     "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"),
     "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"),
+    "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"),
 }