diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 767358f973..f77b4567aa 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -186,11 +186,10 @@ class EvaluationLoop(DataLoaderLoop): def _on_evaluation_model_train(self) -> None: """Sets model to train mode.""" - model_ref = self.trainer.lightning_module if self.trainer.testing: - model_ref.on_test_model_train() + self.trainer._call_lightning_module_hook("on_test_model_train") else: - model_ref.on_validation_model_train() + self.trainer._call_lightning_module_hook("on_validation_model_train") def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: """Runs ``on_{validation/test}_end`` hook.""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 2d467f6b6f..5f0a72588b 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -218,9 +218,9 @@ class EvaluationEpochLoop(Loop): the outputs of the step """ if self.trainer.testing: - output = self.trainer._call_accelerator_hook("test_step", *kwargs.values()) + output = self.trainer._call_ttp_hook("test_step", *kwargs.values()) else: - output = self.trainer._call_accelerator_hook("validation_step", *kwargs.values()) + output = self.trainer._call_ttp_hook("validation_step", *kwargs.values()) return output diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 4f34fb65c4..3fb74911d3 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -130,7 +130,7 @@ class PredictionEpochLoop(Loop): self.batch_progress.increment_started() - predictions = self.trainer._call_accelerator_hook("predict_step", *step_kwargs.values()) + predictions = self.trainer._call_ttp_hook("predict_step", *step_kwargs.values()) self.batch_progress.increment_processed() diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index ad4640bb57..3cf63aeba3 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -180,11 +180,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): # hook self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs) - model_response = self.trainer._call_lightning_module_hook( + response = self.trainer._call_lightning_module_hook( "on_train_batch_start", batch, batch_idx, **extra_kwargs ) - ttp_response = self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) - response = ttp_response if model_response is None else model_response + self.trainer._call_ttp_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) if response == -1: self.batch_progress.increment_processed() raise StopIteration diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 578d80f3b8..21efd02b7a 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -102,7 +102,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): ) # manually capture logged metrics - training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values()) + training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values()) self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index f0a9dc7314..c710aa31e3 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -144,12 +144,10 @@ class Closure(AbstractClosure[ClosureResult]): ) if self._zero_grad_fn is not None: - with self._profiler.profile("zero_grad"): - self._zero_grad_fn() + self._zero_grad_fn() if self._backward_fn is not None and step_output.closure_loss is not None: - with self._profiler.profile("backward"): - self._backward_fn(step_output.closure_loss) + self._backward_fn(step_output.closure_loss) return step_output @@ -320,7 +318,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): return None def backward_fn(loss: Tensor) -> None: - self.trainer.training_type_plugin.backward(loss, optimizer, opt_idx) + self.trainer._call_ttp_hook("backward", loss, optimizer, opt_idx) # check if model weights are nan if self.trainer._terminate_on_nan: @@ -362,8 +360,6 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): train_step_and_backward_closure: the closure function performing the train step and computing the gradients. By default called by the optimizer (if possible) """ - lightning_module = self.trainer.lightning_module - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) # wraps into LightningOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) @@ -371,7 +367,8 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): self.optim_progress.optimizer.step.increment_ready() # model hook - lightning_module.optimizer_step( + self.trainer._call_lightning_module_hook( + "optimizer_step", self.trainer.current_epoch, batch_idx, optimizer, @@ -403,7 +400,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): optimizer: the current optimizer opt_idx: the index of the current optimizer """ - self.trainer.training_type_plugin.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.trainer._call_ttp_hook("optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx) self.optim_progress.optimizer.zero_grad.increment_completed() def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult: @@ -427,7 +424,7 @@ class OptimizerLoop(Loop[_OUTPUTS_TYPE]): ) # manually capture logged metrics - training_step_output = self.trainer._call_accelerator_hook("training_step", *step_kwargs.values()) + training_step_output = self.trainer._call_ttp_hook("training_step", *step_kwargs.values()) self.trainer.training_type_plugin.post_training_step() del step_kwargs diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index fc5de98636..ee1c41764a 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -164,8 +164,7 @@ class TrainingTypePlugin(ABC): def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: """Zeros all model parameter's gradients.""" - model_ref = self.lightning_module - model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) + self.lightning_module.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index ad3dce3c12..e73bf54825 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -31,15 +31,24 @@ class _FxValidator: "on_before_backward": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "backward": _LogOptions( + allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False + ), "on_after_backward": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), "on_before_optimizer_step": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "optimizer_step": _LogOptions( + allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False + ), "on_before_zero_grad": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), + "optimizer_zero_grad": _LogOptions( + allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False + ), "on_init_start": None, "on_init_end": None, "on_fit_start": None, @@ -160,6 +169,8 @@ class _FxValidator: "configure_callbacks": None, "on_validation_model_eval": None, "on_test_model_eval": None, + "on_validation_model_train": None, + "on_test_model_train": None, } @classmethod diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 88f929531e..bb9398570b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1452,7 +1452,7 @@ class Trainer( *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any, - ): + ) -> Any: pl_module = pl_module or self.lightning_module if pl_module is None: @@ -1460,7 +1460,7 @@ class Trainer( fn = getattr(pl_module, hook_name) if not callable(fn): - return None + return prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = hook_name @@ -1479,16 +1479,15 @@ class Trainer( hook_name: str, *args: Any, **kwargs: Any, - ) -> Optional[Any]: - output = None + ) -> None: if hook_name in ("on_init_start", "on_init_end"): # these `Callback` hooks are the only ones that do not take a lightning module. # we also don't profile bc profiler hasn't been set yet for callback in self.callbacks: fn = getattr(callback, hook_name) if callable(fn): - output = fn(self, *args, **kwargs) - return output + fn(self, *args, **kwargs) + return pl_module = self.lightning_module if pl_module: @@ -1500,34 +1499,39 @@ class Trainer( fn = getattr(self, hook_name) if callable(fn): with self.profiler.profile(hook_name): - output = fn(*args, **kwargs) + fn(*args, **kwargs) else: for callback in self.callbacks: fn = getattr(callback, hook_name) if callable(fn): with self.profiler.profile(hook_name): - output = fn(self, self.lightning_module, *args, **kwargs) + fn(self, self.lightning_module, *args, **kwargs) if pl_module: # restore current_fx when nested context pl_module._current_fx_name = prev_fx_name - return output - # TODO: rename to _call_strategy_hook and eventually no longer need this def _call_ttp_hook( self, hook_name: str, *args: Any, **kwargs: Any, - ): + ) -> Any: + pl_module = self.lightning_module + prev_fx_name = pl_module._current_fx_name + pl_module._current_fx_name = hook_name + fn = getattr(self.training_type_plugin, hook_name) if not callable(fn): - return None + return with self.profiler.profile(hook_name): output = fn(*args, **kwargs) + # restore current_fx when nested context + pl_module._current_fx_name = prev_fx_name + return output # TODO: eventually no longer need this @@ -1536,15 +1540,21 @@ class Trainer( hook_name: str, *args: Any, **kwargs: Any, - ) -> Optional[Any]: - self.lightning_module._current_fx_name = hook_name - fn = getattr(self.training_type_plugin, hook_name) + ) -> Any: + pl_module = self.lightning_module + prev_fx_name = pl_module._current_fx_name + pl_module._current_fx_name = hook_name + + fn = getattr(self.accelerator, hook_name) if not callable(fn): - return None + return with self.profiler.profile(hook_name): output = fn(*args, **kwargs) + # restore current_fx when nested context + pl_module._current_fx_name = prev_fx_name + return output @staticmethod diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index b3b6667e97..cebf4cb6f9 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -31,9 +31,9 @@ from tests.models.test_hooks import get_members def test_fx_validator(tmpdir): - funcs_name = sorted(get_members(Callback)) + funcs_name = get_members(Callback) - callbacks_func = [ + callbacks_func = { "on_before_backward", "on_after_backward", "on_before_optimizer_step", @@ -82,9 +82,9 @@ def test_fx_validator(tmpdir): "on_predict_start", "setup", "teardown", - ] + } - not_supported = [ + not_supported = { "on_before_accelerator_backend_setup", "on_fit_end", "on_fit_start", @@ -110,11 +110,10 @@ def test_fx_validator(tmpdir): "on_validation_end", "setup", "teardown", - ] + } - assert funcs_name == sorted( - callbacks_func - ), "Detected new callback function. Need to add its logging permission to FxValidator and update this test" + # Detected new callback function. Need to add its logging permission to FxValidator and update this test + assert funcs_name == callbacks_func validator = _FxValidator() @@ -233,6 +232,7 @@ def test_fx_validator_integration(tmpdir): "prepare_data": "You can't", "configure_callbacks": "You can't", "on_validation_model_eval": "You can't", + "on_validation_model_train": "You can't", "summarize": "not managed by the `Trainer", } model = HookedModel(not_supported) @@ -260,6 +260,7 @@ def test_fx_validator_integration(tmpdir): "on_test_dataloader": "You can't", "test_dataloader": "You can't", "on_test_model_eval": "You can't", + "on_test_model_train": "You can't", "on_test_end": "You can't", } ) diff --git a/tests/trainer/logging_/test_loop_logging.py b/tests/trainer/logging_/test_loop_logging.py index 2c2f2253c4..40ad1bc48d 100644 --- a/tests/trainer/logging_/test_loop_logging.py +++ b/tests/trainer/logging_/test_loop_logging.py @@ -50,9 +50,12 @@ def test_default_level_for_hooks_that_support_logging(): trainer.state.stage = RunningStage.TRAINING hooks = [ "on_before_backward", + "backward", "on_after_backward", "on_before_optimizer_step", + "optimizer_step", "on_before_zero_grad", + "optimizer_zero_grad", "training_step", "training_step_end", "on_batch_start",