Use PyTorch API logging for Lightning Trainer (#6771)

* Update trainer.py

* Update trainer.py

* Update trainer.py
Author: ananthsub (committed 2021-04-15 15:10:34 -07:00 via GitHub)
Parent: f29ecbfd90
Commit: 4c07ab5e99
1 changed file with 10 additions and 1 deletion

@@ -295,7 +295,7 @@ class Trainer(
        """
        super().__init__()
        Trainer._log_api_event("init")
        distributed_backend = distributed_backend or accelerator
        # init connectors
@@ -416,6 +416,7 @@ class Trainer(
                If the model has a predefined val_dataloaders method this will be skipped
        """
        Trainer._log_api_event("fit")
        # we reuse fit for other functions. When already set, it shouldn't be modified.
        if not self.state.running:
            self.state = TrainerState.FITTING
@@ -881,6 +882,7 @@ class Trainer(
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("validate")
        self.verbose_evaluate = verbose
        self.state = TrainerState.VALIDATING
@@ -943,6 +945,7 @@ class Trainer(
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("test")
        self.verbose_evaluate = verbose
        self.state = TrainerState.TESTING
@@ -1039,6 +1042,7 @@ class Trainer(
        # SETUP HOOK
        # --------------------
        # If you supply a datamodule you can't supply dataloaders
        Trainer._log_api_event("predict")
        model = model or self.lightning_module
@@ -1084,6 +1088,7 @@ class Trainer(
                If the model has a predefined val_dataloaders method this will be skipped
        """
        Trainer._log_api_event("tune")
        self.state = TrainerState.TUNING
        self.tuning = True
@@ -1174,3 +1179,7 @@ class Trainer(
        if not skip:
            self._cache_logged_metrics()
        return output

    @staticmethod
    def _log_api_event(event: str) -> None:
        torch._C._log_api_usage_once("lightning.trainer." + event)
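
For reference, torch._C._log_api_usage_once is PyTorch's internal API-usage hook: each key is recorded at most once per process, and the call is effectively a no-op unless an API-usage logger is active. Below is a minimal sketch (not part of this commit) of how the events added here could be observed locally, assuming PyTorch's PYTORCH_API_USAGE_STDERR=1 debug switch; the standalone _log_api_event function simply mirrors the static helper from the diff above.

    import os

    # Choose the stderr logger before importing torch: PyTorch picks its
    # API-usage logger when the first event fires.
    os.environ["PYTORCH_API_USAGE_STDERR"] = "1"

    import torch

    # Mirrors the static helper added in this commit.
    def _log_api_event(event: str) -> None:
        torch._C._log_api_usage_once("lightning.trainer." + event)

    _log_api_event("init")  # echoed to stderr as "lightning.trainer.init"
    _log_api_event("init")  # deduplicated: each key is logged once per process

In production settings a logger is typically installed from C++ via c10::SetAPIUsageLogger, which is how fleet-wide usage telemetry is collected; the stderr switch above is just the debugging path.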