diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4d4fc53f85..b1c29ff2c8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -295,7 +295,7 @@ class Trainer(
         """
         super().__init__()
-
+        Trainer._log_api_event("init")
         distributed_backend = distributed_backend or accelerator
 
         # init connectors
@@ -416,6 +416,7 @@ class Trainer(
             If the model has a predefined val_dataloaders method this will be skipped
         """
+        Trainer._log_api_event("fit")
         # we reuse fit for other functions. When already set, it shouldn't be modified.
         if not self.state.running:
             self.state = TrainerState.FITTING
@@ -881,6 +882,7 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
+        Trainer._log_api_event("validate")
         self.verbose_evaluate = verbose
 
         self.state = TrainerState.VALIDATING
@@ -943,6 +945,7 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
+        Trainer._log_api_event("test")
         self.verbose_evaluate = verbose
 
         self.state = TrainerState.TESTING
@@ -1039,6 +1042,7 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
         # If you supply a datamodule you can't supply dataloaders
+        Trainer._log_api_event("predict")
 
         model = model or self.lightning_module
@@ -1084,6 +1088,7 @@ class Trainer(
             If the model has a predefined val_dataloaders method this will be skipped
         """
+        Trainer._log_api_event("tune")
         self.state = TrainerState.TUNING
         self.tuning = True
@@ -1174,3 +1179,7 @@ class Trainer(
         if not skip:
             self._cache_logged_metrics()
         return output
+
+    @staticmethod
+    def _log_api_event(event: str) -> None:
+        torch._C._log_api_usage_once("lightning.trainer." + event)
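
Note for reviewers (not part of the patch): `torch._C._log_api_usage_once` is intended to record each unique event key at most once per process, so these calls should add negligible overhead on the hot paths. A minimal sketch of how the new events could be observed locally, assuming the installed PyTorch build honors the `PYTORCH_API_USAGE_STDERR` environment variable (checked by PyTorch's logging at import time):

```python
import os

# Assumption: setting PYTORCH_API_USAGE_STDERR before torch is imported
# routes API-usage events to stderr via PyTorch's default usage logger.
os.environ["PYTORCH_API_USAGE_STDERR"] = "1"

import torch

# Emit the same kind of event the patch logs from the Trainer entry points.
torch._C._log_api_usage_once("lightning.trainer.fit")
# A repeated key should be deduplicated, i.e. produce no second log line.
torch._C._log_api_usage_once("lightning.trainer.fit")
```

In an actual Lightning run the events come from the Trainer itself, e.g. constructing a `Trainer` and calling `.fit(model)` would emit `lightning.trainer.init` followed by `lightning.trainer.fit`.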