Logger consistency (#397)
* added comet logger * bug fix in cases where comet was not imported before torch * fixed mlflow logger to be consistent with docs, updated cometLogger and cometLoggers docs + flake 8 compliance
This commit is contained in:
parent
1424157731
commit
ab6794406e
|
@ -72,7 +72,6 @@ mlf_logger = MLFlowLogger(
|
|||
)
|
||||
trainer = Trainer(logger=mlf_logger)
|
||||
```
|
||||
|
||||
Use the logger anywhere in you LightningModule as follows:
|
||||
```python
|
||||
def train_step(...):
|
||||
|
@ -83,6 +82,30 @@ def any_lightning_module_function_or_hook(...):
|
|||
self.logger.experiment.whatever_ml_flow_supports(...)
|
||||
```
|
||||
|
||||
---
|
||||
#### Comet.ml
|
||||
|
||||
Log using [comet](https://www.comet.ml)
|
||||
|
||||
```{.python}
|
||||
from pytorch_lightning.logging import CometLogger
|
||||
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
|
||||
comet_logger = CometLogger(
|
||||
api_key=os.environ["COMET_KEY"],
|
||||
workspace=os.environ["COMET_KEY"],
|
||||
)
|
||||
trainer = Trainer(logger=comet_logger)
|
||||
```
|
||||
Use the logger anywhere in you LightningModule as follows:
|
||||
```python
|
||||
def train_step(...):
|
||||
# example
|
||||
self.logger.experiment.whatever_comet_ml_supports(...)
|
||||
|
||||
def any_lightning_module_function_or_hook(...):
|
||||
self.logger.experiment.whatever_comet_ml_supports(...)
|
||||
```
|
||||
|
||||
---
|
||||
#### Custom logger
|
||||
|
||||
|
|
|
@ -8,3 +8,7 @@ try:
|
|||
from .mlflow_logger import MLFlowLogger
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
try:
|
||||
from .comet_logger import CometLogger
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
from time import time
|
||||
from logging import getLogger
|
||||
from os import environ
|
||||
from comet_ml import Experiment as CometExperiment
|
||||
from .base import LightningLoggerBase, rank_zero_only
|
||||
|
||||
# needed to prevent ImportError and duplicated logs.
|
||||
environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
|
||||
|
||||
|
||||
class CometLogger(LightningLoggerBase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CometLogger, self).__init__()
|
||||
self.experiment = CometExperiment(*args, **kwargs)
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
self.experiment.log_parameters(vars(params))
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step_num):
|
||||
# self.experiment.set_epoch(self, metrics.get('epoch', 0))
|
||||
self.experiment.log_metrics(metrics)
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status):
|
||||
self.experiment.end()
|
|
@ -11,7 +11,7 @@ logger = getLogger(__name__)
|
|||
class MLFlowLogger(LightningLoggerBase):
|
||||
def __init__(self, experiment_name, tracking_uri=None, tags=None):
|
||||
super().__init__()
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri)
|
||||
self.experiment = mlflow.tracking.MlflowClient(tracking_uri)
|
||||
self.experiment_name = experiment_name
|
||||
self._run_id = None
|
||||
self.tags = tags
|
||||
|
@ -21,22 +21,22 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
if self._run_id is not None:
|
||||
return self._run_id
|
||||
|
||||
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
experiment = self.experiment.get_experiment_by_name(self.experiment_name)
|
||||
if experiment is None:
|
||||
logger.warning(
|
||||
f"Experiment with name f{self.experiment_name} not found. Creating it."
|
||||
)
|
||||
self.client.create_experiment(self.experiment_name)
|
||||
experiment = self.client.get_experiment_by_name(self.experiment_name)
|
||||
self.experiment.create_experiment(self.experiment_name)
|
||||
experiment = self.experiment.get_experiment_by_name(self.experiment_name)
|
||||
|
||||
run = self.client.create_run(experiment.experiment_id, tags=self.tags)
|
||||
run = self.experiment.create_run(experiment.experiment_id, tags=self.tags)
|
||||
self._run_id = run.info.run_id
|
||||
return self._run_id
|
||||
|
||||
@rank_zero_only
|
||||
def log_hyperparams(self, params):
|
||||
for k, v in vars(params).items():
|
||||
self.client.log_param(self.run_id, k, v)
|
||||
self.experiment.log_param(self.run_id, k, v)
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step_num=None):
|
||||
|
@ -47,7 +47,7 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
f"Discarding metric with string value {k}={v}"
|
||||
)
|
||||
continue
|
||||
self.client.log_metric(self.run_id, k, v, timestamp_ms, step_num)
|
||||
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step_num)
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
@ -56,4 +56,4 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
def finalize(self, status="FINISHED"):
|
||||
if status == 'success':
|
||||
status = 'FINISHED'
|
||||
self.client.set_terminated(self.run_id, status)
|
||||
self.experiment.set_terminated(self.run_id, status)
|
||||
|
|
Loading…
Reference in New Issue