From d9938da8a4aeda9ba08e97476c581948e82951b7 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 24 Feb 2022 17:03:16 +0530 Subject: [PATCH] add properties to check for trainer state in pytorch profier (#12063) --- pytorch_lightning/profiler/pytorch.py | 56 +++++++++++++++++---------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 7afb962abb..b4fdcc8131 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -131,41 +131,57 @@ class ScheduleWrapper: self._prev_schedule_action: Optional[ProfilerAction] = None self._start_action_name: Optional[str] = None + @property + def is_training(self): + return self._current_action.endswith("training_step") + + @property + def is_validating(self): + return self._current_action.endswith("validation_step") + + @property + def is_testing(self): + return self._current_action.endswith("test_step") + + @property + def is_predicting(self): + return self._current_action.endswith("predict_step") + @property def num_step(self) -> int: - if self._current_action.endswith("training_step"): + if self.is_training: return self._num_training_step - if self._current_action.endswith("validation_step"): + if self.is_validating: return self._num_validation_step - if self._current_action.endswith("test_step"): + if self.is_testing: return self._num_test_step - if self._current_action.endswith("predict_step"): + if self.is_predicting: return self._num_predict_step return 0 def _step(self) -> None: - if self._current_action.endswith("training_step"): + if self.is_training: self._num_training_step += 1 - elif self._current_action.endswith("validation_step"): + elif self.is_validating: if self._start_action_name.endswith("on_fit_start"): if self._num_training_step > 0: self._num_validation_step += 1 else: self._num_validation_step += 1 - elif self._current_action.endswith("test_step"): + elif self.is_testing: self._num_test_step += 1 - elif self._current_action.endswith("predict_step"): + elif self.is_predicting: self._num_predict_step += 1 @property def has_finished(self) -> bool: - if self._current_action.endswith("training_step"): + if self.is_training: return self._training_step_reached_end - if self._current_action.endswith("validation_step"): + if self.is_validating: return self._validation_step_reached_end - if self._current_action.endswith("test_step"): + if self.is_testing: return self._test_step_reached_end - if self._current_action.endswith("predict_step"): + if self.is_predicting: return self._predict_step_reached_end return False @@ -182,13 +198,13 @@ class ScheduleWrapper: # and the action is still WARMUP in train and pytorch will recognize this as error. action = ProfilerAction.RECORD if action == ProfilerAction.RECORD_AND_SAVE: - if self._current_action.endswith("training_step"): + if self.is_training: self._training_step_reached_end = True - elif self._current_action.endswith("validation_step"): + elif self.is_validating: self._validation_step_reached_end = True - elif self._current_action.endswith("test_step"): + elif self.is_testing: self._test_step_reached_end = True - elif self._current_action.endswith("predict_step"): + elif self.is_predicting: self._predict_step_reached_end = True self._prev_schedule_action = action return action @@ -322,13 +338,13 @@ class PyTorchProfiler(BaseProfiler): @property def _total_steps(self) -> int: trainer = self._lightning_module.trainer - if self._schedule._current_action.endswith("training_step"): + if self._schedule.is_training: return trainer.num_training_batches - if self._schedule._current_action.endswith("validation_step"): + if self._schedule.is_validating: return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches) - if self._schedule._current_action.endswith("test_step"): + if self._schedule.is_testing: return sum(trainer.num_test_batches) - if self._schedule._current_action.endswith("predict_step"): + if self._schedule.is_predicting: return sum(trainer.num_predict_batches) def _should_override_schedule(self) -> bool: