rename _call_ttp_hook to _call_strategy_hook (#11150)
This commit is contained in:
parent
a3e2ef2be0
commit
f95976d602
|
@ -89,8 +89,8 @@ class YieldLoop(OptimizerLoop):
|
|||
self.trainer.training_type_plugin.post_training_step()
|
||||
|
||||
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
|
||||
ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
|
||||
training_step_output = ttp_output if model_output is None else model_output
|
||||
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
|
||||
training_step_output = strategy_output if model_output is None else model_output
|
||||
|
||||
# The closure result takes care of properly detaching the loss for logging and peforms
|
||||
# some additional checks that the output format is correct.
|
||||
|
|
|
@ -200,11 +200,11 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
if self.trainer.testing:
|
||||
self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
|
||||
self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
|
||||
self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
|
||||
self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
|
||||
else:
|
||||
self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
|
||||
self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
|
||||
self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)
|
||||
self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)
|
||||
|
||||
def _on_evaluation_model_eval(self) -> None:
|
||||
"""Sets model to eval mode."""
|
||||
|
@ -225,11 +225,11 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
if self.trainer.testing:
|
||||
self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
|
||||
self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
|
||||
self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
|
||||
self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
|
||||
else:
|
||||
self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
|
||||
self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
|
||||
self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)
|
||||
self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)
|
||||
|
||||
# reset the logger connector state
|
||||
self.trainer.logger_connector.reset_results()
|
||||
|
|
|
@ -114,7 +114,7 @@ class PredictionLoop(DataLoaderLoop):
|
|||
# hook
|
||||
self.trainer._call_callback_hooks("on_predict_start")
|
||||
self.trainer._call_lightning_module_hook("on_predict_start")
|
||||
self.trainer._call_ttp_hook("on_predict_start")
|
||||
self.trainer._call_strategy_hook("on_predict_start")
|
||||
|
||||
self.trainer._call_callback_hooks("on_predict_epoch_start")
|
||||
self.trainer._call_lightning_module_hook("on_predict_epoch_start")
|
||||
|
@ -142,7 +142,7 @@ class PredictionLoop(DataLoaderLoop):
|
|||
# hook
|
||||
self.trainer._call_callback_hooks("on_predict_end")
|
||||
self.trainer._call_lightning_module_hook("on_predict_end")
|
||||
self.trainer._call_ttp_hook("on_predict_end")
|
||||
self.trainer._call_strategy_hook("on_predict_end")
|
||||
|
||||
def _on_predict_model_eval(self) -> None:
|
||||
"""Calls ``on_predict_model_eval`` hook."""
|
||||
|
|
|
@ -112,7 +112,7 @@ class EvaluationEpochLoop(Loop):
|
|||
raise StopIteration
|
||||
|
||||
if not data_fetcher.store_on_device:
|
||||
batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
|
||||
batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
|
||||
|
||||
self.batch_progress.increment_ready()
|
||||
|
||||
|
@ -222,9 +222,9 @@ class EvaluationEpochLoop(Loop):
|
|||
the outputs of the step
|
||||
"""
|
||||
if self.trainer.testing:
|
||||
output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
|
||||
output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
|
||||
else:
|
||||
output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
|
||||
output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
|
||||
|
||||
return output
|
||||
|
||||
|
@ -232,8 +232,8 @@ class EvaluationEpochLoop(Loop):
|
|||
"""Calls the `{validation/test}_step_end` hook."""
|
||||
hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
|
||||
model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
|
||||
ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
|
||||
output = ttp_output if model_output is None else model_output
|
||||
strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
|
||||
output = strategy_output if model_output is None else model_output
|
||||
return output
|
||||
|
||||
def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
|
||||
|
|
|
@ -96,7 +96,7 @@ class PredictionEpochLoop(Loop):
|
|||
if batch is None:
|
||||
raise StopIteration
|
||||
|
||||
batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
|
||||
batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
|
||||
|
||||
self.batch_progress.increment_ready()
|
||||
|
||||
|
@ -128,7 +128,7 @@ class PredictionEpochLoop(Loop):
|
|||
|
||||
self.batch_progress.increment_started()
|
||||
|
||||
predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
|
||||
predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
|
||||
|
||||
self.batch_progress.increment_processed()
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
|
||||
|
||||
if not data_fetcher.store_on_device:
|
||||
batch = self.trainer._call_ttp_hook("batch_to_device", batch)
|
||||
batch = self.trainer._call_strategy_hook("batch_to_device", batch)
|
||||
|
||||
self.batch_progress.increment_ready()
|
||||
|
||||
|
@ -182,7 +182,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
response = self.trainer._call_lightning_module_hook(
|
||||
"on_train_batch_start", batch, batch_idx, **extra_kwargs
|
||||
)
|
||||
self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||
self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
|
||||
if response == -1:
|
||||
self.batch_progress.increment_processed()
|
||||
raise StopIteration
|
||||
|
|
|
@ -195,7 +195,7 @@ class FitLoop(Loop):
|
|||
self._results.to(device=self.trainer.lightning_module.device)
|
||||
self.trainer._call_callback_hooks("on_train_start")
|
||||
self.trainer._call_lightning_module_hook("on_train_start")
|
||||
self.trainer._call_ttp_hook("on_train_start")
|
||||
self.trainer._call_strategy_hook("on_train_start")
|
||||
|
||||
def on_advance_start(self) -> None: # type: ignore[override]
|
||||
"""Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
|
||||
|
@ -252,7 +252,7 @@ class FitLoop(Loop):
|
|||
# hook
|
||||
self.trainer._call_callback_hooks("on_train_end")
|
||||
self.trainer._call_lightning_module_hook("on_train_end")
|
||||
self.trainer._call_ttp_hook("on_train_end")
|
||||
self.trainer._call_strategy_hook("on_train_end")
|
||||
|
||||
# give accelerators a chance to finish
|
||||
self.trainer.training_type_plugin.on_train_end()
|
||||
|
|
|
@ -102,14 +102,14 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
|
|||
)
|
||||
|
||||
# manually capture logged metrics
|
||||
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
|
||||
self.trainer.training_type_plugin.post_training_step()
|
||||
|
||||
del step_kwargs
|
||||
|
||||
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
|
||||
ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
|
||||
training_step_output = ttp_output if model_output is None else model_output
|
||||
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
|
||||
training_step_output = strategy_output if model_output is None else model_output
|
||||
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
|
||||
|
||||
result = self.output_result_cls.from_training_step_output(training_step_output)
|
||||
|
|
|
@ -318,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
return None
|
||||
|
||||
def backward_fn(loss: Tensor) -> None:
|
||||
self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
|
||||
self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
|
||||
|
||||
# check if model weights are nan
|
||||
if self.trainer._terminate_on_nan:
|
||||
|
@ -400,7 +400,9 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
optimizer: the current optimizer
|
||||
opt_idx: the index of the current optimizer
|
||||
"""
|
||||
self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
|
||||
self.trainer._call_strategy_hook(
|
||||
"optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
|
||||
)
|
||||
self.optim_progress.optimizer.zero_grad.increment_completed()
|
||||
|
||||
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
|
||||
|
@ -424,14 +426,14 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
|
|||
)
|
||||
|
||||
# manually capture logged metrics
|
||||
training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
|
||||
training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
|
||||
self.trainer.training_type_plugin.post_training_step()
|
||||
|
||||
del step_kwargs
|
||||
|
||||
model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
|
||||
ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
|
||||
training_step_output = ttp_output if model_output is None else model_output
|
||||
strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
|
||||
training_step_output = strategy_output if model_output is None else model_output
|
||||
|
||||
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
|
||||
|
||||
|
|
|
@ -1570,8 +1570,7 @@ class Trainer(
|
|||
# restore current_fx when nested context
|
||||
pl_module._current_fx_name = prev_fx_name
|
||||
|
||||
# TODO: rename to _call_strategy_hook and eventually no longer need this
|
||||
def _call_ttp_hook(
|
||||
def _call_strategy_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
*args: Any,
|
||||
|
|
Loading…
Reference in New Issue