From f95976d602d294ea3943dadd6ca4e8f2ed7bab21 Mon Sep 17 00:00:00 2001
From: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
Date: Sat, 18 Dec 2021 17:53:03 -0800
Subject: [PATCH] rename _call_ttp_hook to _call_strategy_hook (#11150)

---
 pl_examples/loop_examples/yielding_training_step.py  |  4 ++--
 .../loops/dataloader/evaluation_loop.py              |  8 ++++----
 .../loops/dataloader/prediction_loop.py              |  4 ++--
 .../loops/epoch/evaluation_epoch_loop.py             | 10 +++++-----
 .../loops/epoch/prediction_epoch_loop.py             |  4 ++--
 pytorch_lightning/loops/epoch/training_epoch_loop.py |  4 ++--
 pytorch_lightning/loops/fit_loop.py                  |  4 ++--
 pytorch_lightning/loops/optimization/manual_loop.py  |  6 +++---
 .../loops/optimization/optimizer_loop.py             | 12 +++++++-----
 pytorch_lightning/trainer/trainer.py                 |  3 +--
 10 files changed, 30 insertions(+), 29 deletions(-)

diff --git a/pl_examples/loop_examples/yielding_training_step.py b/pl_examples/loop_examples/yielding_training_step.py
index 69c84e15c9..739d4f0f2b 100644
--- a/pl_examples/loop_examples/yielding_training_step.py
+++ b/pl_examples/loop_examples/yielding_training_step.py
@@ -89,8 +89,8 @@ class YieldLoop(OptimizerLoop):
         self.trainer.training_type_plugin.post_training_step()
 
         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-        ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-        training_step_output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+        training_step_output = strategy_output if model_output is None else model_output
 
         # The closure result takes care of properly detaching the loss for logging and peforms
         # some additional checks that the output format is correct.
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 2954927196..688cbdbf59 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -200,11 +200,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)
 
     def _on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode."""
@@ -225,11 +225,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)
 
         # reset the logger connector state
         self.trainer.logger_connector.reset_results()
diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py
index 8a0b50a30a..3f227736d0 100644
--- a/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -114,7 +114,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_start")
         self.trainer._call_lightning_module_hook("on_predict_start")
-        self.trainer._call_ttp_hook("on_predict_start")
+        self.trainer._call_strategy_hook("on_predict_start")
 
         self.trainer._call_callback_hooks("on_predict_epoch_start")
         self.trainer._call_lightning_module_hook("on_predict_epoch_start")
@@ -142,7 +142,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_end")
         self.trainer._call_lightning_module_hook("on_predict_end")
-        self.trainer._call_ttp_hook("on_predict_end")
+        self.trainer._call_strategy_hook("on_predict_end")
 
     def _on_predict_model_eval(self) -> None:
         """Calls ``on_predict_model_eval`` hook."""
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index bf45986a96..69af7133d3 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -112,7 +112,7 @@ class EvaluationEpochLoop(Loop):
             raise StopIteration
 
         if not data_fetcher.store_on_device:
-            batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
 
         self.batch_progress.increment_ready()
 
@@ -222,9 +222,9 @@ class EvaluationEpochLoop(Loop):
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
 
         return output
 
@@ -232,8 +232,8 @@ class EvaluationEpochLoop(Loop):
         """Calls the `{validation/test}_step_end` hook."""
         hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
         model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
-        ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
-        output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
+        output = strategy_output if model_output is None else model_output
         return output
 
     def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
index e9d0b85d35..3fb49e7d4b 100644
--- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -96,7 +96,7 @@ class PredictionEpochLoop(Loop):
         if batch is None:
             raise StopIteration
 
-        batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
+        batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
 
         self.batch_progress.increment_ready()
 
@@ -128,7 +128,7 @@ class PredictionEpochLoop(Loop):
 
         self.batch_progress.increment_started()
 
-        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())
 
         self.batch_progress.increment_processed()
 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 2689f12088..c001a6de47 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -156,7 +156,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
 
         if not data_fetcher.store_on_device:
-            batch = self.trainer._call_ttp_hook("batch_to_device", batch)
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch)
 
         self.batch_progress.increment_ready()
 
@@ -182,7 +182,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             response = self.trainer._call_lightning_module_hook(
                 "on_train_batch_start", batch, batch_idx, **extra_kwargs
             )
-            self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+            self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
             if response == -1:
                 self.batch_progress.increment_processed()
                 raise StopIteration
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 8f48696926..49b5a1ba5a 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -195,7 +195,7 @@ class FitLoop(Loop):
         self._results.to(device=self.trainer.lightning_module.device)
         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
-        self.trainer._call_ttp_hook("on_train_start")
+        self.trainer._call_strategy_hook("on_train_start")
 
     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -252,7 +252,7 @@ class FitLoop(Loop):
         # hook
         self.trainer._call_callback_hooks("on_train_end")
         self.trainer._call_lightning_module_hook("on_train_end")
-        self.trainer._call_ttp_hook("on_train_end")
+        self.trainer._call_strategy_hook("on_train_end")
 
         # give accelerators a chance to finish
         self.trainer.training_type_plugin.on_train_end()
diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py
index 21efd02b7a..9577d9e15d 100644
--- a/pytorch_lightning/loops/optimization/manual_loop.py
+++ b/pytorch_lightning/loops/optimization/manual_loop.py
@@ -102,14 +102,14 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
             )
 
             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()
 
             del step_kwargs
 
             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output
             self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
 
             result = self.output_result_cls.from_training_step_output(training_step_output)
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index c710aa31e3..d54d06ba53 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -318,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             return None
 
         def backward_fn(loss: Tensor) -> None:
-            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
+            self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
 
             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -400,7 +400,9 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_strategy_hook(
+            "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
+        )
         self.optim_progress.optimizer.zero_grad.increment_completed()
 
     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -424,14 +426,14 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             )
 
             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()
 
             del step_kwargs
 
             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output
 
             self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3740be974b..523a4e76d5 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1570,8 +1570,7 @@ class Trainer(
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name
 
-    # TODO: rename to _call_strategy_hook and eventually no longer need this
-    def _call_ttp_hook(
+    def _call_strategy_hook(
         self,
         hook_name: str,
         *args: Any,