add properties to check for trainer state in pytorch profier (#12063)
This commit is contained in:
parent
1fa0639bca
commit
d9938da8a4
|
@ -131,41 +131,57 @@ class ScheduleWrapper:
|
|||
self._prev_schedule_action: Optional[ProfilerAction] = None
|
||||
self._start_action_name: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_training(self):
|
||||
return self._current_action.endswith("training_step")
|
||||
|
||||
@property
|
||||
def is_validating(self):
|
||||
return self._current_action.endswith("validation_step")
|
||||
|
||||
@property
|
||||
def is_testing(self):
|
||||
return self._current_action.endswith("test_step")
|
||||
|
||||
@property
|
||||
def is_predicting(self):
|
||||
return self._current_action.endswith("predict_step")
|
||||
|
||||
@property
|
||||
def num_step(self) -> int:
|
||||
if self._current_action.endswith("training_step"):
|
||||
if self.is_training:
|
||||
return self._num_training_step
|
||||
if self._current_action.endswith("validation_step"):
|
||||
if self.is_validating:
|
||||
return self._num_validation_step
|
||||
if self._current_action.endswith("test_step"):
|
||||
if self.is_testing:
|
||||
return self._num_test_step
|
||||
if self._current_action.endswith("predict_step"):
|
||||
if self.is_predicting:
|
||||
return self._num_predict_step
|
||||
return 0
|
||||
|
||||
def _step(self) -> None:
|
||||
if self._current_action.endswith("training_step"):
|
||||
if self.is_training:
|
||||
self._num_training_step += 1
|
||||
elif self._current_action.endswith("validation_step"):
|
||||
elif self.is_validating:
|
||||
if self._start_action_name.endswith("on_fit_start"):
|
||||
if self._num_training_step > 0:
|
||||
self._num_validation_step += 1
|
||||
else:
|
||||
self._num_validation_step += 1
|
||||
elif self._current_action.endswith("test_step"):
|
||||
elif self.is_testing:
|
||||
self._num_test_step += 1
|
||||
elif self._current_action.endswith("predict_step"):
|
||||
elif self.is_predicting:
|
||||
self._num_predict_step += 1
|
||||
|
||||
@property
|
||||
def has_finished(self) -> bool:
|
||||
if self._current_action.endswith("training_step"):
|
||||
if self.is_training:
|
||||
return self._training_step_reached_end
|
||||
if self._current_action.endswith("validation_step"):
|
||||
if self.is_validating:
|
||||
return self._validation_step_reached_end
|
||||
if self._current_action.endswith("test_step"):
|
||||
if self.is_testing:
|
||||
return self._test_step_reached_end
|
||||
if self._current_action.endswith("predict_step"):
|
||||
if self.is_predicting:
|
||||
return self._predict_step_reached_end
|
||||
return False
|
||||
|
||||
|
@ -182,13 +198,13 @@ class ScheduleWrapper:
|
|||
# and the action is still WARMUP in train and pytorch will recognize this as error.
|
||||
action = ProfilerAction.RECORD
|
||||
if action == ProfilerAction.RECORD_AND_SAVE:
|
||||
if self._current_action.endswith("training_step"):
|
||||
if self.is_training:
|
||||
self._training_step_reached_end = True
|
||||
elif self._current_action.endswith("validation_step"):
|
||||
elif self.is_validating:
|
||||
self._validation_step_reached_end = True
|
||||
elif self._current_action.endswith("test_step"):
|
||||
elif self.is_testing:
|
||||
self._test_step_reached_end = True
|
||||
elif self._current_action.endswith("predict_step"):
|
||||
elif self.is_predicting:
|
||||
self._predict_step_reached_end = True
|
||||
self._prev_schedule_action = action
|
||||
return action
|
||||
|
@ -322,13 +338,13 @@ class PyTorchProfiler(BaseProfiler):
|
|||
@property
|
||||
def _total_steps(self) -> int:
|
||||
trainer = self._lightning_module.trainer
|
||||
if self._schedule._current_action.endswith("training_step"):
|
||||
if self._schedule.is_training:
|
||||
return trainer.num_training_batches
|
||||
if self._schedule._current_action.endswith("validation_step"):
|
||||
if self._schedule.is_validating:
|
||||
return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
|
||||
if self._schedule._current_action.endswith("test_step"):
|
||||
if self._schedule.is_testing:
|
||||
return sum(trainer.num_test_batches)
|
||||
if self._schedule._current_action.endswith("predict_step"):
|
||||
if self._schedule.is_predicting:
|
||||
return sum(trainer.num_predict_batches)
|
||||
|
||||
def _should_override_schedule(self) -> bool:
|
||||
|
|
Loading…
Reference in New Issue