rename _call_ttp_hook to _call_strategy_hook (#11150)

Danielle Pintz 2021-12-18 17:53:03 -08:00 committed by GitHub
parent a3e2ef2be0
commit f95976d602
10 changed files with 30 additions and 29 deletions


@@ -89,8 +89,8 @@ class YieldLoop(OptimizerLoop):
         self.trainer.training_type_plugin.post_training_step()

         model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-        ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-        training_step_output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+        training_step_output = strategy_output if model_output is None else model_output

         # The closure result takes care of properly detaching the loss for logging and performs
         # some additional checks that the output format is correct.
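Note on the two-line change above: the LightningModule's `training_step_end` return value takes precedence, and the strategy's return value is only a fallback. A minimal sketch of that precedence rule (a hypothetical helper for illustration, not part of the Lightning API):

def resolve_step_end_output(model_output, strategy_output):
    # Mirrors `strategy_output if model_output is None else model_output` above.
    return strategy_output if model_output is None else model_output

assert resolve_step_end_output(None, "from_strategy") == "from_strategy"          # fall back to strategy
assert resolve_step_end_output("from_module", "from_strategy") == "from_module"  # module output wins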


@@ -200,11 +200,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)

     def _on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode."""
@@ -225,11 +225,11 @@ class EvaluationLoop(DataLoaderLoop):
         if self.trainer.testing:
             self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
         else:
             self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
             self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
-            self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)
+            self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)

         # reset the logger connector state
         self.trainer.logger_connector.reset_results()
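The start/end hunks above all follow the same fan-out order: callbacks first, then the LightningModule, then the strategy. A minimal sketch of that ordering (`fire_hook` is a hypothetical helper for illustration, not a Trainer method):

def fire_hook(trainer, hook_name: str, *args, **kwargs) -> None:
    # Order matters: callbacks run before the LightningModule, and the
    # strategy (formerly the "training type plugin") runs last.
    trainer._call_callback_hooks(hook_name, *args, **kwargs)
    trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
    trainer._call_strategy_hook(hook_name, *args, **kwargs)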


@@ -114,7 +114,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_start")
         self.trainer._call_lightning_module_hook("on_predict_start")
-        self.trainer._call_ttp_hook("on_predict_start")
+        self.trainer._call_strategy_hook("on_predict_start")

         self.trainer._call_callback_hooks("on_predict_epoch_start")
         self.trainer._call_lightning_module_hook("on_predict_epoch_start")
@@ -142,7 +142,7 @@ class PredictionLoop(DataLoaderLoop):
         # hook
         self.trainer._call_callback_hooks("on_predict_end")
         self.trainer._call_lightning_module_hook("on_predict_end")
-        self.trainer._call_ttp_hook("on_predict_end")
+        self.trainer._call_strategy_hook("on_predict_end")

     def _on_predict_model_eval(self) -> None:
         """Calls ``on_predict_model_eval`` hook."""


@@ -112,7 +112,7 @@ class EvaluationEpochLoop(Loop):
             raise StopIteration

         if not data_fetcher.store_on_device:
-            batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=(dataloader_idx or 0))

         self.batch_progress.increment_ready()
@@ -222,9 +222,9 @@ class EvaluationEpochLoop(Loop):
             the outputs of the step
         """
         if self.trainer.testing:
-            output = self.trainer._call_ttp_hook("test_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
         else:
-            output = self.trainer._call_ttp_hook("validation_step", *kwargs.values())
+            output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())

         return output
@@ -232,8 +232,8 @@ class EvaluationEpochLoop(Loop):
         """Calls the `{validation/test}_step_end` hook."""
         hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"

         model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
-        ttp_output = self.trainer._call_ttp_hook(hook_name, *args, **kwargs)
-        output = ttp_output if model_output is None else model_output
+        strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
+        output = strategy_output if model_output is None else model_output

         return output

     def _on_evaluation_batch_start(self, **kwargs: Any) -> None:


@@ -96,7 +96,7 @@ class PredictionEpochLoop(Loop):
             if batch is None:
                 raise StopIteration

-            batch = self.trainer._call_ttp_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)
+            batch = self.trainer._call_strategy_hook("batch_to_device", batch, dataloader_idx=dataloader_idx)

             self.batch_progress.increment_ready()
@@ -128,7 +128,7 @@ class PredictionEpochLoop(Loop):
         self.batch_progress.increment_started()

-        predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values())
+        predictions = self.trainer._call_strategy_hook("predict_step", *step_kwargs.values())

         self.batch_progress.increment_processed()


@@ -156,7 +156,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

             if not data_fetcher.store_on_device:
-                batch = self.trainer._call_ttp_hook("batch_to_device", batch)
+                batch = self.trainer._call_strategy_hook("batch_to_device", batch)

         self.batch_progress.increment_ready()
@@ -182,7 +182,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         response = self.trainer._call_lightning_module_hook(
             "on_train_batch_start", batch, batch_idx, **extra_kwargs
         )
-        self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
+        self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
         if response == -1:
             self.batch_progress.increment_processed()
             raise StopIteration


@@ -195,7 +195,7 @@ class FitLoop(Loop):
             self._results.to(device=self.trainer.lightning_module.device)

         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
-        self.trainer._call_ttp_hook("on_train_start")
+        self.trainer._call_strategy_hook("on_train_start")

     def on_advance_start(self) -> None:  # type: ignore[override]
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
@@ -252,7 +252,7 @@ class FitLoop(Loop):
         # hook
         self.trainer._call_callback_hooks("on_train_end")
         self.trainer._call_lightning_module_hook("on_train_end")
-        self.trainer._call_ttp_hook("on_train_end")
+        self.trainer._call_strategy_hook("on_train_end")

         # give accelerators a chance to finish
         self.trainer.training_type_plugin.on_train_end()


@@ -102,14 +102,14 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

             del step_kwargs

             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output

             self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

             result = self.output_result_cls.from_training_step_output(training_step_output)


@@ -318,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             return None

         def backward_fn(loss: Tensor) -> None:
-            self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx)
+            self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)

             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -400,7 +400,9 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             optimizer: the current optimizer
             opt_idx: the index of the current optimizer
         """
-        self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        self.trainer._call_strategy_hook(
+            "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
+        )
         self.optim_progress.optimizer.zero_grad.increment_completed()

     def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
@@ -424,14 +426,14 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
             )

             # manually capture logged metrics
-            training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values())
+            training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
             self.trainer.training_type_plugin.post_training_step()

             del step_kwargs

             model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
-            ttp_output = self.trainer._call_ttp_hook("training_step_end", training_step_output)
-            training_step_output = ttp_output if model_output is None else model_output
+            strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
+            training_step_output = strategy_output if model_output is None else model_output

             self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)


@@ -1570,8 +1570,7 @@ class Trainer(
         # restore current_fx when nested context
         pl_module._current_fx_name = prev_fx_name

-    # TODO: rename to _call_strategy_hook and eventually no longer need this
-    def _call_ttp_hook(
+    def _call_strategy_hook(
         self,
         hook_name: str,
         *args: Any,
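The hunk ends mid-signature, so the method body is not shown here. Based on the `_current_fx_name` bookkeeping visible in the context lines above, a plausible sketch of what the renamed dispatcher does (the profiling label and the `callable` guard are assumptions, not the verbatim implementation):

from typing import Any

def _call_strategy_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    # Sketch of a Trainer method, shown standalone for illustration.
    # Point logging at the hook being run, as the lightning-module variant does.
    pl_module = self.lightning_module
    prev_fx_name = pl_module._current_fx_name
    pl_module._current_fx_name = hook_name

    fn = getattr(self.training_type_plugin, hook_name)
    if not callable(fn):
        return None

    # Assumed profiling wrapper; the exact label is a guess.
    with self.profiler.profile(f"[Strategy]{type(self.training_type_plugin).__name__}.{hook_name}"):
        output = fn(*args, **kwargs)

    # restore current_fx when nested context
    pl_module._current_fx_name = prev_fx_name
    return output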