From 4c07ab5e99dd20c1f309d9e73cdaacc1ebad9499 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Thu, 15 Apr 2021 15:10:34 -0700
Subject: [PATCH] Use PyTorch API logging for Lightning Trainer (#6771)

* Update trainer.py

* Update trainer.py

* Update trainer.py
---
 pytorch_lightning/trainer/trainer.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4d4fc53f85..b1c29ff2c8 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -295,7 +295,7 @@ class Trainer(

         """
         super().__init__()
-
+        Trainer._log_api_event("init")
         distributed_backend = distributed_backend or accelerator

         # init connectors
@@ -416,6 +416,7 @@ class Trainer(
             If the model has a predefined val_dataloaders method this will be skipped

         """
+        Trainer._log_api_event("fit")
         # we reuse fit for other functions. When already set, it shouldn't be modified.
         if not self.state.running:
             self.state = TrainerState.FITTING
@@ -881,6 +882,7 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
+        Trainer._log_api_event("validate")
         self.verbose_evaluate = verbose

         self.state = TrainerState.VALIDATING
@@ -943,6 +945,7 @@ class Trainer(
         # --------------------
         # SETUP HOOK
         # --------------------
+        Trainer._log_api_event("test")
         self.verbose_evaluate = verbose

         self.state = TrainerState.TESTING
@@ -1039,6 +1042,7 @@ class Trainer(
         # SETUP HOOK
         # --------------------
         # If you supply a datamodule you can't supply dataloaders
+        Trainer._log_api_event("predict")

         model = model or self.lightning_module

@@ -1084,6 +1088,7 @@ class Trainer(
             If the model has a predefined val_dataloaders method this will be skipped

         """
+        Trainer._log_api_event("tune")
         self.state = TrainerState.TUNING
         self.tuning = True

@@ -1174,3 +1179,7 @@ class Trainer(
         if not skip:
             self._cache_logged_metrics()
         return output
+
+    @staticmethod
+    def _log_api_event(event: str) -> None:
+        torch._C._log_api_usage_once("lightning.trainer." + event)
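
Note: torch._C._log_api_usage_once is PyTorch's one-shot API-usage telemetry hook: each unique key string is recorded at most once per process, and the call is effectively a no-op unless the hosting environment installs a usage handler on the C++ side (e.g. via c10::SetAPIUsageHandler), so it adds no overhead for ordinary users. Below is a minimal standalone sketch of the pattern this patch introduces; the MyTrainer class is a hypothetical stand-in for illustration only, not Lightning's actual Trainer.

    import torch


    class MyTrainer:
        """Hypothetical stand-in illustrating the API-usage logging pattern."""

        def __init__(self) -> None:
            # Record that a trainer was constructed. Constructing several
            # trainers in one process emits the event only once, because the
            # logged key is identical each time.
            MyTrainer._log_api_event("init")

        def fit(self) -> None:
            # Each public entry point logs its own key on entry.
            MyTrainer._log_api_event("fit")
            # ... training logic would go here ...

        @staticmethod
        def _log_api_event(event: str) -> None:
            # No-op unless a usage handler is registered (e.g. in managed
            # fleets), so it is safe to call unconditionally.
            torch._C._log_api_usage_once("lightning.trainer." + event)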