add properties to check for trainer state in pytorch profiler (#12063)

This commit is contained in:
Rohit Gupta 2022-02-24 17:03:16 +05:30 committed by GitHub
parent 1fa0639bca
commit d9938da8a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 36 additions and 20 deletions

View File

@ -131,41 +131,57 @@ class ScheduleWrapper:
self._prev_schedule_action: Optional[ProfilerAction] = None self._prev_schedule_action: Optional[ProfilerAction] = None
self._start_action_name: Optional[str] = None self._start_action_name: Optional[str] = None
@property
def is_training(self):
return self._current_action.endswith("training_step")
@property
def is_validating(self):
return self._current_action.endswith("validation_step")
@property
def is_testing(self):
return self._current_action.endswith("test_step")
@property
def is_predicting(self):
return self._current_action.endswith("predict_step")
@property @property
def num_step(self) -> int: def num_step(self) -> int:
if self._current_action.endswith("training_step"): if self.is_training:
return self._num_training_step return self._num_training_step
if self._current_action.endswith("validation_step"): if self.is_validating:
return self._num_validation_step return self._num_validation_step
if self._current_action.endswith("test_step"): if self.is_testing:
return self._num_test_step return self._num_test_step
if self._current_action.endswith("predict_step"): if self.is_predicting:
return self._num_predict_step return self._num_predict_step
return 0 return 0
def _step(self) -> None: def _step(self) -> None:
if self._current_action.endswith("training_step"): if self.is_training:
self._num_training_step += 1 self._num_training_step += 1
elif self._current_action.endswith("validation_step"): elif self.is_validating:
if self._start_action_name.endswith("on_fit_start"): if self._start_action_name.endswith("on_fit_start"):
if self._num_training_step > 0: if self._num_training_step > 0:
self._num_validation_step += 1 self._num_validation_step += 1
else: else:
self._num_validation_step += 1 self._num_validation_step += 1
elif self._current_action.endswith("test_step"): elif self.is_testing:
self._num_test_step += 1 self._num_test_step += 1
elif self._current_action.endswith("predict_step"): elif self.is_predicting:
self._num_predict_step += 1 self._num_predict_step += 1
@property @property
def has_finished(self) -> bool: def has_finished(self) -> bool:
if self._current_action.endswith("training_step"): if self.is_training:
return self._training_step_reached_end return self._training_step_reached_end
if self._current_action.endswith("validation_step"): if self.is_validating:
return self._validation_step_reached_end return self._validation_step_reached_end
if self._current_action.endswith("test_step"): if self.is_testing:
return self._test_step_reached_end return self._test_step_reached_end
if self._current_action.endswith("predict_step"): if self.is_predicting:
return self._predict_step_reached_end return self._predict_step_reached_end
return False return False
@ -182,13 +198,13 @@ class ScheduleWrapper:
# and the action is still WARMUP in train and pytorch will recognize this as error. # and the action is still WARMUP in train and pytorch will recognize this as error.
action = ProfilerAction.RECORD action = ProfilerAction.RECORD
if action == ProfilerAction.RECORD_AND_SAVE: if action == ProfilerAction.RECORD_AND_SAVE:
if self._current_action.endswith("training_step"): if self.is_training:
self._training_step_reached_end = True self._training_step_reached_end = True
elif self._current_action.endswith("validation_step"): elif self.is_validating:
self._validation_step_reached_end = True self._validation_step_reached_end = True
elif self._current_action.endswith("test_step"): elif self.is_testing:
self._test_step_reached_end = True self._test_step_reached_end = True
elif self._current_action.endswith("predict_step"): elif self.is_predicting:
self._predict_step_reached_end = True self._predict_step_reached_end = True
self._prev_schedule_action = action self._prev_schedule_action = action
return action return action
@ -322,13 +338,13 @@ class PyTorchProfiler(BaseProfiler):
@property @property
def _total_steps(self) -> int: def _total_steps(self) -> int:
trainer = self._lightning_module.trainer trainer = self._lightning_module.trainer
if self._schedule._current_action.endswith("training_step"): if self._schedule.is_training:
return trainer.num_training_batches return trainer.num_training_batches
if self._schedule._current_action.endswith("validation_step"): if self._schedule.is_validating:
return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches) return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
if self._schedule._current_action.endswith("test_step"): if self._schedule.is_testing:
return sum(trainer.num_test_batches) return sum(trainer.num_test_batches)
if self._schedule._current_action.endswith("predict_step"): if self._schedule.is_predicting:
return sum(trainer.num_predict_batches) return sum(trainer.num_predict_batches)
def _should_override_schedule(self) -> bool: def _should_override_schedule(self) -> bool: