diff --git a/pl_examples/loop_examples/yielding_training_step.py b/pl_examples/loop_examples/yielding_training_step.py
index 69c84e15c9..739d4f0f2b 100644
--- a/pl_examples/loop_examples/yielding_training_step.py
+++ b/pl_examples/loop_examples/yielding_training_step.py
@@ -89,8 +89,8 @@ class YieldLoop(OptimizerLoop):
         self.trainer.training_type_plugin.post_training_step()

         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-        ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-        training_step_output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+        training_step_output = strategy_output if model_output is None else model_output

         # The closure result takes care of properly detaching the loss for logging and peforms
         # some additional checks that the output format is correct.
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 2954927196..688cbdbf59 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -200,11 +200,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)

     def _on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode."""
@@ -225,11 +225,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)

         # reset the logger connector state
         self.trainer.logger_connector.reset_results()
diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py
index 8a0b50a30a..3f227736d0 100644
--- a/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -114,7 +114,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_start")
         self.trainer._call_lightning_module_hook("on_predict_start")
-        self.trainer._call_ttp_hook("on_predict_start")
+        self.trainer._call_strategy_hook("on_predict_start")

         self.trainer._call_callback_hooks("on_predict_epoch_start")
         self.trainer._call_lightning_module_hook("on_predict_epoch_start")
@@ -142,7 +142,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_end")
         self.trainer._call_lightning_module_hook("on_predict_end")
-        self.trainer._call_ttp_hook("on_predict_end")
+        self.trainer._call_strategy_hook("on_predict_end")

     def _on_predict_model_eval(self) -> None:
         """Calls ``on_predict_model_eval`` hook."""
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index bf45986a96..69af7133d3 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -112,7 +112,7 @@ class EvaluationEpochLoop(Loop):
             raise StopIteration

         if not data_fetcher.store_on_device:
-            batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))

         self.batch_progress.increment_ready()

@@ -222,9 +222,9 @@ class EvaluationEpochLoop(Loop):
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())

         return output

@@ -232,8 +232,8 @@ class EvaluationEpochLoop(Loop):
         """Calls the `{validation/test}_step_end` hook."""
         hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
         model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
-        ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
-        output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
+        output = strategy_output if model_output is None else model_output
         return output

     def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index e9d0b85d35..3fb49e7d4b 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -96,7 +96,7 @@ class PredictionEpochLoop(Loop):
         if batch is None:
             raise StopIteration

-        batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
+        batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)

         self.batch_progress.increment_ready()

@@ -128,7 +128,7 @@ class PredictionEpochLoop(Loop):

         self.batch_progress.increment_started()

-        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

         self.batch_progress.increment_processed()

diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 2689f12088..c001a6de47 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -156,7 +156,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

         if not data_fetcher.store_on_device:
-            batch = self.trainer._call_ttp_hook("batch_to_device", batch)
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch)

         self.batch_progress.increment_ready()

@@ -182,7 +182,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             response = self.trainer._call_lightning_module_hook(
                 "on_train_batch_start", batch, batch_idx, **extra_kwargs
             )
-            self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
             if response == -1:
                 self.batch_progress.increment_processed()
                 raise StopIteration
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 8f48696926..49b5a1ba5a 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -195,7 +195,7 @@ class FitLoop(Loop):
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
-        self.trainer._call_ttp_hook("on_train_start")
+        self.trainer._call_strategy_hook("on_train_start")

     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -252,7 +252,7 @@ class FitLoop(Loop):
         # hook
         self.trainer._call_callback_hooks("on_train_end")
         self.trainer._call_lightning_module_hook("on_train_end")
-        self.trainer._call_ttp_hook("on_train_end")
+        self.trainer._call_strategy_hook("on_train_end")

         # give accelerators a chance to finish
         self.trainer.training_type_plugin.on_train_end()
diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py
index 21efd02b7a..9577d9e15d 100644
--- a/pytorch_lightning/loops/optimization/manual_loop.py
+++ b/pytorch_lightning/loops/optimization/manual_loop.py
@@ -102,14 +102,14 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

             del step_kwargs

             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output
         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

         result = self.output_result_cls.from_training_step_output(training_step_output)
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index c710aa31e3..d54d06ba53 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -318,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             return None

         def backward_fn(loss: Tensor) -> None:
-            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
+            self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)

             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -400,7 +400,9 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_strategy_hook(
+            "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
+        )
         self.optim_progress.optimizer.zero_grad.increment_completed()

     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -424,14 +426,14 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

             del step_kwargs

             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output

         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3740be974b..523a4e76d5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1570,8 +1570,7 @@ class Trainer(
         # restore current_fx when nested context
         pl_module._current_fx_name = prev_fx_name

-    # TODO: rename to _call_strategy_hook and eventually no longer need this
-    def _call_ttp_hook(
+    def _call_strategy_hook(
         self,
         hook_name: str,
         *args: Any,
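
Note: the final trainer.py hunk is truncated after the renamed signature, so the body of `_call_strategy_hook` is not shown here. For orientation only, below is a minimal standalone sketch of the dispatch pattern the changed call sites rely on. It assumes the helper resolves the named hook on the trainer's training type plugin ("strategy"), skips it if absent or not callable, and forwards the positional and keyword arguments; the `_Sketch*` names are hypothetical and are not part of this diff or of the pytorch_lightning API.

# Illustrative sketch only -- not the Trainer implementation from this diff.
from typing import Any, Optional


class _SketchStrategy:
    """Hypothetical stand-in for a training type plugin ("strategy")."""

    def batch_to_device(self, batch: Any, dataloader_idx: int = 0) -> Any:
        # A real strategy would move tensors to its root device here.
        return batch


class _SketchTrainer:
    """Hypothetical stand-in for the trainer that owns the strategy."""

    def __init__(self, strategy: _SketchStrategy) -> None:
        self.training_type_plugin = strategy

    def _call_strategy_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
        # Resolve the hook on the strategy; do nothing if it is not defined there.
        fn = getattr(self.training_type_plugin, hook_name, None)
        if not callable(fn):
            return None
        return fn(*args, **kwargs)


if __name__ == "__main__":
    trainer = _SketchTrainer(_SketchStrategy())
    # Mirrors the renamed call sites in the loops above.
    batch = trainer._call_strategy_hook("batch_to_device", [1, 2, 3], dataloader_idx=0)
    assert batch == [1, 2, 3]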