From 0ea8f3984141118003460d850ce2810581ef3f46 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 25 Mar 2021 16:03:06 +0000
Subject: [PATCH] Resolve schedule step bug for PyTorch Profiler (#6674)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/profiler/pytorch.py | 54 ++++++++++---------
 .../trainer/connectors/profiler_connector.py | 2 +-
 requirements/adjust_versions.py | 1 +
 3 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py
index 73abc1baf9..fa2c2917f9 100644
--- a/pytorch_lightning/profiler/pytorch.py
+++ b/pytorch_lightning/profiler/pytorch.py
@@ -107,6 +107,15 @@ class ScheduleWrapper:
         if not _KINETO_AVAILABLE:
             raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.")
         self._schedule = schedule
+        self.reset()
+
+    def setup(self, start_action_name: str) -> None:
+        self._start_action_name = start_action_name
+
+    def pre_step(self, current_action: str) -> None:
+        self._current_action = current_action
+
+    def reset(self):
         self._num_training_step_and_backward = 0
         self._num_validation_step = 0
         self._num_test_step = 0
@@ -119,12 +128,6 @@ class ScheduleWrapper:
         self._current_action: Optional[str] = None
         self._start_action_name: Optional[str] = None
 
-    def setup(self, start_action_name: str) -> None:
-        self._start_action_name = start_action_name
-
-    def pre_step(self, current_action: str) -> None:
-        self._current_action = current_action
-
     @property
     def num_step(self) -> int:
         if self._current_action == "training_step_and_backward":
@@ -142,8 +145,9 @@ class ScheduleWrapper:
         if self._current_action == "training_step_and_backward":
             self._num_training_step_and_backward += 1
         elif self._current_action == "validation_step":
-            if self._start_action_name == "on_train_start" and self._num_training_step_and_backward > 0:
-                self._num_validation_step += 1
+            if self._start_action_name == "on_fit_start":
+                if self._num_training_step_and_backward > 0:
+                    self._num_validation_step += 1
             else:
                 self._num_validation_step += 1
         elif self._current_action == "test_step":
@@ -210,7 +214,7 @@ class PyTorchProfiler(BaseProfiler):
         "count",
     }
     START_RECORD_FUNCTIONS = {
-        'on_train_start',
+        'on_fit_start',
         'on_validation_start',
         'on_test_start',
         'on_predict_start',
@@ -289,8 +293,9 @@ class PyTorchProfiler(BaseProfiler):
         self._export_to_chrome = export_to_chrome
         self._row_limit = row_limit
         self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
-        self._record_functions_start = record_functions | self.START_RECORD_FUNCTIONS
-        self._record_functions = record_functions | self.RECORD_FUNCTIONS
+        self._user_record_functions = record_functions
+        self._record_functions_start = self._user_record_functions | self.START_RECORD_FUNCTIONS
+        self._record_functions = self._user_record_functions | self.RECORD_FUNCTIONS
         self._record_module_names = record_module_names
         self._profiler_kwargs = profiler_kwargs
 
@@ -304,14 +309,14 @@ class PyTorchProfiler(BaseProfiler):
         self._schedule: Optional[ScheduleWrapper] = None
 
         if _KINETO_AVAILABLE:
-            self.__init_kineto__(profiler_kwargs)
+            self._init_kineto(profiler_kwargs)
 
         if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
             raise MisconfigurationException(
                 f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
             )
 
-    def __init_kineto__(self, profiler_kwargs: Any):
+    def _init_kineto(self, profiler_kwargs: Any) -> None:
         has_schedule = "schedule" in profiler_kwargs
         self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
 
@@ -362,7 +367,7 @@ class PyTorchProfiler(BaseProfiler):
     def _default_schedule() -> Optional[callable]:
         if _KINETO_AVAILABLE:
             # Those schedule defaults allow the profiling overhead to be negligible over training time.
-            return torch.profiler.schedule(wait=1, warmup=1, active=2)
+            return torch.profiler.schedule(wait=1, warmup=1, active=3)
 
     def _default_activities(self) -> List['ProfilerActivity']:
         activities = []
@@ -374,10 +379,6 @@ class PyTorchProfiler(BaseProfiler):
             activities.append(ProfilerActivity.CUDA)
         return activities
 
-    @property
-    def step_action_names(self) -> Set[str]:
-        return self.STEP_FUNCTIONS | self._record_functions
-
     def start(self, action_name: str) -> None:
         if self.profiler is None and action_name in self._record_functions_start:
 
@@ -411,9 +412,6 @@ class PyTorchProfiler(BaseProfiler):
             recording.__enter__()
             self._recording_map[action_name] = recording
 
-        if self._schedule is not None:
-            self._schedule.pre_step(action_name)
-
     def stop(self, action_name: str) -> None:
         if action_name in self._recording_map:
             self._recording_map[action_name].__exit__(None, None, None)
@@ -422,16 +420,14 @@ class PyTorchProfiler(BaseProfiler):
         if not _KINETO_AVAILABLE or self._emit_nvtx:
             return
 
-        if action_name in self.step_action_names:
+        if self.profiler is not None and action_name in self.STEP_FUNCTIONS:
             if self._schedule is not None:
-                self._schedule._current_action = action_name
+                self._schedule.pre_step(action_name)
 
             def on_trace_ready(profiler):
-                filename = f"{action_name}_{self.local_rank}"
-
                 if self.dirpath is not None:
                     if self._export_to_chrome:
-                        handler = tensorboard_trace_handler(self.dirpath, filename)
+                        handler = tensorboard_trace_handler(self.dirpath, self._prepare_filename(extension=""))
                         handler(profiler)
 
                     if self._export_to_flame_graph:
@@ -442,6 +438,9 @@ class PyTorchProfiler(BaseProfiler):
 
             if not self._has_on_trace_ready:
                 self.profiler.on_trace_ready = on_trace_ready
+
+            if self._schedule is not None:
+                self.profiler.step_num = self._schedule.num_step
             self.profiler.step()
 
     def summary(self) -> str:
@@ -492,6 +491,9 @@ class PyTorchProfiler(BaseProfiler):
             self._cache_functions_events()
             self.profiler = None
 
+        if self._schedule is not None:
+            self._schedule.reset()
+
         if self._parent_profiler is not None:
             self._parent_profiler.__exit__(None, None, None)
             self._parent_profiler = None
diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py
index 191e871146..fa1002d70a 100644
--- a/pytorch_lightning/trainer/connectors/profiler_connector.py
+++ b/pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -57,5 +57,5 @@ class ProfilerConnector:
     def setup(self) -> None:
         trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        trainer.profiler.lightning_module = proxy(trainer.lightning_module)
+        trainer.profiler._lightning_module = proxy(trainer.lightning_module)
         trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
diff --git a/requirements/adjust_versions.py b/requirements/adjust_versions.py
index c1499cd4ea..d0dfbc59e2 100644
--- a/requirements/adjust_versions.py
+++ b/requirements/adjust_versions.py
@@ -11,6 +11,7 @@ VERSIONS_LUT: Dict[str, Dict[str, Any]] = {
     "1.7.0": dict(torchvision="0.8.1", torchtext="0.8"),
     "1.7.1": dict(torchvision="0.8.2", torchtext="0.8.1"),
     "1.8.0": dict(torchvision="0.9.0", torchtext="0.9"),
+    "1.8.1": dict(torchvision="0.9.0", torchtext="0.9"),
 }
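
For context, a minimal usage sketch of the profiler API these hunks touch; it is not part of the patch. It assumes PyTorch >= 1.8.1 (so kineto and torch.profiler.schedule are available) and a pytorch_lightning build containing this change; the dirpath value, Trainer arguments, and the omitted LightningModule are illustrative placeholders only.

# Sketch only: configure PyTorchProfiler with an explicit torch.profiler schedule,
# which the patched ScheduleWrapper wraps and advances via pre_step()/num_step.
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# Mirrors the new default schedule: skip 1 step, warm up for 1, record 3.
schedule = torch.profiler.schedule(wait=1, warmup=1, active=3)

profiler = PyTorchProfiler(
    dirpath="profiler_logs",  # placeholder output directory for traces
    export_to_chrome=True,    # write traces through tensorboard_trace_handler
    schedule=schedule,        # forwarded via profiler_kwargs to torch.profiler.profile
)

trainer = Trainer(profiler=profiler, max_epochs=1, limit_train_batches=8)
# trainer.fit(model)  # `model` would be any LightningModule; omitted here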