Allow kwargs in Wandb & Neptune + kwargs docstring (#3475)
* Allow kwargs in WandbLogger * isort * kwargs docstring * typo * kwargs for other loggers * pep and isort * formatting * fix failing test Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
8eb77cd06a
commit
07b857769a
|
@ -18,13 +18,13 @@ import operator
|
|||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from functools import wraps
|
||||
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping
|
||||
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
class LightningLoggerBase(ABC):
|
||||
|
|
|
@ -18,15 +18,14 @@ Comet
|
|||
"""
|
||||
|
||||
import os
|
||||
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Union, Any
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
try:
|
||||
from comet_ml import Experiment as CometExperiment
|
||||
from comet_ml import ExistingExperiment as CometExistingExperiment
|
||||
from comet_ml import OfflineExperiment as CometOfflineExperiment
|
||||
from comet_ml import BaseExperiment as CometBaseExperiment
|
||||
from comet_ml import ExistingExperiment as CometExistingExperiment
|
||||
from comet_ml import Experiment as CometExperiment
|
||||
from comet_ml import OfflineExperiment as CometOfflineExperiment
|
||||
from comet_ml import generate_guid
|
||||
|
||||
try:
|
||||
|
@ -34,7 +33,7 @@ try:
|
|||
except ImportError: # pragma: no-cover
|
||||
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
|
||||
from comet_ml.papi import API # pragma: no-cover
|
||||
from comet_ml.config import get_config, get_api_key
|
||||
from comet_ml.config import get_api_key, get_config
|
||||
except ImportError: # pragma: no-cover
|
||||
CometExperiment = None
|
||||
CometExistingExperiment = None
|
||||
|
@ -51,8 +50,8 @@ from torch import is_tensor
|
|||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class CometLogger(LightningLoggerBase):
|
||||
|
@ -102,7 +101,6 @@ class CometLogger(LightningLoggerBase):
|
|||
if either exists.
|
||||
save_dir: Required in offline mode. The path for the directory to save local
|
||||
comet logs. If given, this also sets the directory for saving checkpoints.
|
||||
workspace: Optional. Name of workspace for this user
|
||||
project_name: Optional. Send your experiment to a specific project.
|
||||
Otherwise will be sent to Uncategorized Experiments.
|
||||
If the project name does not already exist, Comet.ml will create a new project.
|
||||
|
@ -114,21 +112,21 @@ class CometLogger(LightningLoggerBase):
|
|||
the experiment will be in online or offline mode. This is useful if you use
|
||||
save_dir to control the checkpoints directory and have a ~/.comet.config
|
||||
file but still want to run offline experiments.
|
||||
\**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by
|
||||
:class:`CometExperiment` can be passed as keyword arguments in this logger.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
rest_api_key: Optional[str] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
experiment_key: Optional[str] = None,
|
||||
offline: bool = False,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if not _COMET_AVAILABLE:
|
||||
raise ImportError(
|
||||
"You want to use `comet_ml` logger which is not installed yet,"
|
||||
|
@ -157,7 +155,6 @@ class CometLogger(LightningLoggerBase):
|
|||
|
||||
log.info(f"CometLogger will be initialized in {self.mode} mode")
|
||||
|
||||
self.workspace = workspace
|
||||
self._project_name = project_name
|
||||
self._experiment_key = experiment_key
|
||||
self._experiment_name = experiment_name
|
||||
|
@ -197,13 +194,14 @@ class CometLogger(LightningLoggerBase):
|
|||
if self.mode == "online":
|
||||
if self._experiment_key is None:
|
||||
self._experiment = CometExperiment(
|
||||
api_key=self.api_key, workspace=self.workspace, project_name=self._project_name, **self._kwargs
|
||||
api_key=self.api_key,
|
||||
project_name=self._project_name,
|
||||
**self._kwargs,
|
||||
)
|
||||
self._experiment_key = self._experiment.get_key()
|
||||
else:
|
||||
self._experiment = CometExistingExperiment(
|
||||
api_key=self.api_key,
|
||||
workspace=self.workspace,
|
||||
project_name=self._project_name,
|
||||
previous_experiment=self._experiment_key,
|
||||
**self._kwargs,
|
||||
|
@ -211,7 +209,6 @@ class CometLogger(LightningLoggerBase):
|
|||
else:
|
||||
self._experiment = CometOfflineExperiment(
|
||||
offline_directory=self.save_dir,
|
||||
workspace=self.workspace,
|
||||
project_name=self._project_name,
|
||||
**self._kwargs,
|
||||
)
|
||||
|
|
|
@ -23,14 +23,14 @@ import csv
|
|||
import io
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
|
||||
|
||||
|
||||
class ExperimentWriter(object):
|
||||
|
@ -116,11 +116,12 @@ class CSVLogger(LightningLoggerBase):
|
|||
directory for existing versions, then automatically assigns the next available version.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None
|
||||
):
|
||||
super().__init__()
|
||||
self._save_dir = save_dir
|
||||
self._name = name or ''
|
||||
|
|
|
@ -18,7 +18,7 @@ MLflow
|
|||
"""
|
||||
from argparse import Namespace
|
||||
from time import time
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
try:
|
||||
import mlflow
|
||||
|
@ -34,7 +34,6 @@ from pytorch_lightning import _logger as log
|
|||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
|
||||
LOCAL_FILE_URI_PREFIX = "file:"
|
||||
|
||||
|
||||
|
@ -77,12 +76,13 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
experiment_name: str = 'default',
|
||||
tracking_uri: Optional[str] = None,
|
||||
tags: Optional[Dict[str, Any]] = None,
|
||||
save_dir: Optional[str] = './mlruns'):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str = 'default',
|
||||
tracking_uri: Optional[str] = None,
|
||||
tags: Optional[Dict[str, Any]] = None,
|
||||
save_dir: Optional[str] = './mlruns'
|
||||
):
|
||||
if not _MLFLOW_AVAILABLE:
|
||||
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
|
||||
' install it with `pip install mlflow`.')
|
||||
|
|
|
@ -17,8 +17,7 @@ Neptune
|
|||
-------
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from typing import Optional, List, Dict, Any, Union, Iterable
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
try:
|
||||
import neptune
|
||||
|
@ -159,41 +158,19 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
experiment_name: Optional. Editable name of the experiment.
|
||||
Name is displayed in the experiment’s Details (Metadata section) and
|
||||
in experiments view as a column.
|
||||
upload_source_files: Optional. List of source files to be uploaded.
|
||||
Must be list of str or single str. Uploaded sources are displayed
|
||||
in the experiment’s Source code tab.
|
||||
If ``None`` is passed, the Python file from which the experiment was created will be uploaded.
|
||||
Pass an empty list (``[]``) to upload no files.
|
||||
Unix style pathname pattern expansion is supported.
|
||||
For example, you can pass ``'\*.py'``
|
||||
to upload all python source files from the current directory.
|
||||
For recursion lookup use ``'\**/\*.py'`` (for Python 3.5 and later).
|
||||
For more information see :mod:`glob` library.
|
||||
params: Optional. Parameters of the experiment.
|
||||
After experiment creation params are read-only.
|
||||
Parameters are displayed in the experiment’s Parameters section and
|
||||
each key-value pair can be viewed in the experiments view as a column.
|
||||
properties: Optional. Default is ``{}``. Properties of the experiment.
|
||||
They are editable after the experiment is created.
|
||||
Properties are displayed in the experiment’s Details section and
|
||||
each key-value pair can be viewed in the experiments view as a column.
|
||||
tags: Optional. Default is ``[]``. Must be list of str. Tags of the experiment.
|
||||
They are editable after the experiment is created (see: ``append_tag()`` and ``remove_tag()``).
|
||||
Tags are displayed in the experiment’s Details section and can be viewed
|
||||
in the experiments view as a column.
|
||||
\**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by
|
||||
:func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
api_key: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
close_after_fit: Optional[bool] = True,
|
||||
offline_mode: bool = False,
|
||||
experiment_name: Optional[str] = None,
|
||||
upload_source_files: Optional[List[str]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
close_after_fit: Optional[bool] = True,
|
||||
offline_mode: bool = False,
|
||||
experiment_name: Optional[str] = None,
|
||||
**kwargs
|
||||
):
|
||||
if not _NEPTUNE_AVAILABLE:
|
||||
raise ImportError('You want to use `neptune` logger which is not installed yet,'
|
||||
' install it with `pip install neptune-client`.')
|
||||
|
@ -203,10 +180,6 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
self.offline_mode = offline_mode
|
||||
self.close_after_fit = close_after_fit
|
||||
self.experiment_name = experiment_name
|
||||
self.upload_source_files = upload_source_files
|
||||
self.params = params
|
||||
self.properties = properties
|
||||
self.tags = tags
|
||||
self._kwargs = kwargs
|
||||
self._experiment_id = None
|
||||
self._experiment = self._create_or_get_experiment()
|
||||
|
@ -391,13 +364,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
project = session.get_project(self.project_name)
|
||||
|
||||
if self._experiment_id is None:
|
||||
exp = project.create_experiment(
|
||||
name=self.experiment_name,
|
||||
params=self.params,
|
||||
properties=self.properties,
|
||||
tags=self.tags,
|
||||
upload_source_files=self.upload_source_files,
|
||||
**self._kwargs)
|
||||
exp = project.create_experiment(name=self.experiment_name, **self._kwargs)
|
||||
else:
|
||||
exp = project.get_experiments(id=self._experiment_id)[0]
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ TensorBoard
|
|||
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Union, Any
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
|
@ -27,11 +27,11 @@ from pkg_resources import parse_version
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.saving import save_hparams_to_yaml
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
try:
|
||||
from omegaconf import Container, OmegaConf
|
||||
|
@ -67,7 +67,8 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
model.
|
||||
default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
|
||||
called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
|
||||
\**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
|
||||
\**kwargs: Additional arguments like `comment`, `filename_suffix`, etc. used by
|
||||
:class:`SummaryWriter` can be passed as keyword arguments in this logger.
|
||||
|
||||
"""
|
||||
NAME_HPARAMS_FILE = 'hparams.yaml'
|
||||
|
|
|
@ -17,7 +17,7 @@ Test Tube
|
|||
---------
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from typing import Optional, Dict, Any, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
try:
|
||||
from test_tube import Experiment
|
||||
|
@ -26,9 +26,9 @@ except ImportError: # pragma: no-cover
|
|||
Experiment = None
|
||||
_TEST_TUBE_AVAILABLE = False
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
|
||||
class TestTubeLogger(LightningLoggerBase):
|
||||
|
|
|
@ -18,7 +18,7 @@ Weights and Biases
|
|||
"""
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Optional, List, Dict, Union, Any
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -36,7 +36,7 @@ from pytorch_lightning.utilities import rank_zero_only
|
|||
|
||||
|
||||
class WandbLogger(LightningLoggerBase):
|
||||
"""
|
||||
r"""
|
||||
Log using `Weights and Biases <https://www.wandb.com/>`_. Install it with pip:
|
||||
|
||||
.. code-block:: bash
|
||||
|
@ -51,11 +51,10 @@ class WandbLogger(LightningLoggerBase):
|
|||
anonymous: Enables or explicitly disables anonymous logging.
|
||||
version: Sets the version, mainly used to resume a previous run.
|
||||
project: The name of the project to which this run will belong.
|
||||
tags: Tags associated with this run.
|
||||
log_model: Save checkpoints in wandb dir to upload on W&B servers.
|
||||
experiment: WandB experiment object
|
||||
entity: The team posting this run (default: your username or your default team)
|
||||
group: A unique string shared by all runs in a given group
|
||||
experiment: WandB experiment object.
|
||||
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
|
||||
:func:`wandb.init` can be passed as keyword arguments in this logger.
|
||||
|
||||
Example:
|
||||
>>> from pytorch_lightning.loggers import WandbLogger
|
||||
|
@ -70,19 +69,19 @@ class WandbLogger(LightningLoggerBase):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
offline: bool = False,
|
||||
id: Optional[str] = None,
|
||||
anonymous: bool = False,
|
||||
version: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
log_model: bool = False,
|
||||
experiment=None,
|
||||
entity=None,
|
||||
group: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
offline: bool = False,
|
||||
id: Optional[str] = None,
|
||||
anonymous: bool = False,
|
||||
version: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
log_model: bool = False,
|
||||
experiment=None,
|
||||
**kwargs
|
||||
):
|
||||
if not _WANDB_AVAILABLE:
|
||||
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
|
||||
' install it with `pip install wandb`.')
|
||||
|
@ -91,13 +90,11 @@ class WandbLogger(LightningLoggerBase):
|
|||
self._save_dir = save_dir
|
||||
self._anonymous = 'allow' if anonymous else None
|
||||
self._id = version or id
|
||||
self._tags = tags
|
||||
self._project = project
|
||||
self._experiment = experiment
|
||||
self._offline = offline
|
||||
self._entity = entity
|
||||
self._log_model = log_model
|
||||
self._group = group
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
|
@ -126,8 +123,7 @@ class WandbLogger(LightningLoggerBase):
|
|||
os.environ['WANDB_MODE'] = 'dryrun'
|
||||
self._experiment = wandb.init(
|
||||
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
|
||||
reinit=True, id=self._id, resume='allow', tags=self._tags, entity=self._entity,
|
||||
group=self._group)
|
||||
reinit=True, id=self._id, resume='allow', **self._kwargs)
|
||||
# save checkpoints in wandb dir to upload on W&B servers
|
||||
if self._log_model:
|
||||
self._save_dir = self._experiment.dir
|
||||
|
|
|
@ -69,7 +69,7 @@ def test_comet_logger_experiment_name():
|
|||
|
||||
_ = logger.experiment
|
||||
|
||||
comet.assert_called_once_with(api_key=api_key, project_name=None, workspace=None)
|
||||
comet.assert_called_once_with(api_key=api_key, project_name=None)
|
||||
|
||||
comet().set_name.assert_called_once_with(experiment_name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue