Allow kwargs in Wandb & Neptune + kwargs docstring (#3475)

* Allow kwargs in WandbLogger

* isort

* kwargs docstring

* typo

* kwargs for other loggers

* pep and isort

* formatting

* fix failing test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Rohit Gupta 2020-09-19 22:21:43 +05:30 committed by GitHub
parent 8eb77cd06a
commit 07b857769a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 108 deletions

View File

@ -18,13 +18,13 @@ import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping
from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
class LightningLoggerBase(ABC):

View File

@ -18,15 +18,14 @@ Comet
"""
import os
from argparse import Namespace
from typing import Optional, Dict, Union, Any
from typing import Any, Dict, Optional, Union
try:
from comet_ml import Experiment as CometExperiment
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import BaseExperiment as CometBaseExperiment
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import generate_guid
try:
@ -34,7 +33,7 @@ try:
except ImportError: # pragma: no-cover
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover
from comet_ml.config import get_config, get_api_key
from comet_ml.config import get_api_key, get_config
except ImportError: # pragma: no-cover
CometExperiment = None
CometExistingExperiment = None
@ -51,8 +50,8 @@ from torch import is_tensor
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class CometLogger(LightningLoggerBase):
@ -102,7 +101,6 @@ class CometLogger(LightningLoggerBase):
if either exists.
save_dir: Required in offline mode. The path for the directory to save local
comet logs. If given, this also sets the directory for saving checkpoints.
workspace: Optional. Name of workspace for this user
project_name: Optional. Send your experiment to a specific project.
Otherwise will be sent to Uncategorized Experiments.
If the project name does not already exist, Comet.ml will create a new project.
@ -114,21 +112,21 @@ class CometLogger(LightningLoggerBase):
the experiment will be in online or offline mode. This is useful if you use
save_dir to control the checkpoints directory and have a ~/.comet.config
file but still want to run offline experiments.
\**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by
:class:`CometExperiment` can be passed as keyword arguments in this logger.
"""
def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
offline: bool = False,
**kwargs,
**kwargs
):
if not _COMET_AVAILABLE:
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
@ -157,7 +155,6 @@ class CometLogger(LightningLoggerBase):
log.info(f"CometLogger will be initialized in {self.mode} mode")
self.workspace = workspace
self._project_name = project_name
self._experiment_key = experiment_key
self._experiment_name = experiment_name
@ -197,13 +194,14 @@ class CometLogger(LightningLoggerBase):
if self.mode == "online":
if self._experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key, workspace=self.workspace, project_name=self._project_name, **self._kwargs
api_key=self.api_key,
project_name=self._project_name,
**self._kwargs,
)
self._experiment_key = self._experiment.get_key()
else:
self._experiment = CometExistingExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self._project_name,
previous_experiment=self._experiment_key,
**self._kwargs,
@ -211,7 +209,6 @@ class CometLogger(LightningLoggerBase):
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self._project_name,
**self._kwargs,
)

View File

@ -23,14 +23,14 @@ import csv
import io
import os
from argparse import Namespace
from typing import Optional, Dict, Any, Union
from typing import Any, Dict, Optional, Union
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
class ExperimentWriter(object):
@ -116,11 +116,12 @@ class CSVLogger(LightningLoggerBase):
directory for existing versions, then automatically assigns the next available version.
"""
def __init__(self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None):
def __init__(
self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None
):
super().__init__()
self._save_dir = save_dir
self._name = name or ''

View File

@ -18,7 +18,7 @@ MLflow
"""
from argparse import Namespace
from time import time
from typing import Optional, Dict, Any, Union
from typing import Any, Dict, Optional, Union
try:
import mlflow
@ -34,7 +34,6 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only
LOCAL_FILE_URI_PREFIX = "file:"
@ -77,12 +76,13 @@ class MLFlowLogger(LightningLoggerBase):
"""
def __init__(self,
experiment_name: str = 'default',
tracking_uri: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns'):
def __init__(
self,
experiment_name: str = 'default',
tracking_uri: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = './mlruns'
):
if not _MLFLOW_AVAILABLE:
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
' install it with `pip install mlflow`.')

View File

@ -17,8 +17,7 @@ Neptune
-------
"""
from argparse import Namespace
from typing import Optional, List, Dict, Any, Union, Iterable
from typing import Any, Dict, Iterable, List, Optional, Union
try:
import neptune
@ -159,41 +158,19 @@ class NeptuneLogger(LightningLoggerBase):
experiment_name: Optional. Editable name of the experiment.
Name is displayed in the experiments Details (Metadata section) and
in experiments view as a column.
upload_source_files: Optional. List of source files to be uploaded.
Must be list of str or single str. Uploaded sources are displayed
in the experiments Source code tab.
If ``None`` is passed, the Python file from which the experiment was created will be uploaded.
Pass an empty list (``[]``) to upload no files.
Unix style pathname pattern expansion is supported.
For example, you can pass ``'\*.py'``
to upload all python source files from the current directory.
For recursion lookup use ``'\**/\*.py'`` (for Python 3.5 and later).
For more information see :mod:`glob` library.
params: Optional. Parameters of the experiment.
After experiment creation params are read-only.
Parameters are displayed in the experiments Parameters section and
each key-value pair can be viewed in the experiments view as a column.
properties: Optional. Default is ``{}``. Properties of the experiment.
They are editable after the experiment is created.
Properties are displayed in the experiments Details section and
each key-value pair can be viewed in the experiments view as a column.
tags: Optional. Default is ``[]``. Must be list of str. Tags of the experiment.
They are editable after the experiment is created (see: ``append_tag()`` and ``remove_tag()``).
Tags are displayed in the experiments Details section and can be viewed
in the experiments view as a column.
\**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by
:func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger.
"""
def __init__(self,
api_key: Optional[str] = None,
project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True,
offline_mode: bool = False,
experiment_name: Optional[str] = None,
upload_source_files: Optional[List[str]] = None,
params: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
**kwargs):
def __init__(
self,
api_key: Optional[str] = None,
project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True,
offline_mode: bool = False,
experiment_name: Optional[str] = None,
**kwargs
):
if not _NEPTUNE_AVAILABLE:
raise ImportError('You want to use `neptune` logger which is not installed yet,'
' install it with `pip install neptune-client`.')
@ -203,10 +180,6 @@ class NeptuneLogger(LightningLoggerBase):
self.offline_mode = offline_mode
self.close_after_fit = close_after_fit
self.experiment_name = experiment_name
self.upload_source_files = upload_source_files
self.params = params
self.properties = properties
self.tags = tags
self._kwargs = kwargs
self._experiment_id = None
self._experiment = self._create_or_get_experiment()
@ -391,13 +364,7 @@ class NeptuneLogger(LightningLoggerBase):
project = session.get_project(self.project_name)
if self._experiment_id is None:
exp = project.create_experiment(
name=self.experiment_name,
params=self.params,
properties=self.properties,
tags=self.tags,
upload_source_files=self.upload_source_files,
**self._kwargs)
exp = project.create_experiment(name=self.experiment_name, **self._kwargs)
else:
exp = project.get_experiments(id=self._experiment_id)[0]

View File

@ -19,7 +19,7 @@ TensorBoard
import os
from argparse import Namespace
from typing import Optional, Dict, Union, Any
from typing import Any, Dict, Optional, Union
from warnings import warn
import torch
@ -27,11 +27,11 @@ from pkg_resources import parse_version
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.core.lightning import LightningModule
try:
from omegaconf import Container, OmegaConf
@ -67,7 +67,8 @@ class TensorBoardLogger(LightningLoggerBase):
model.
default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
\**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
\**kwargs: Additional arguments like `comment`, `filename_suffix`, etc. used by
:class:`SummaryWriter` can be passed as keyword arguments in this logger.
"""
NAME_HPARAMS_FILE = 'hparams.yaml'

View File

@ -17,7 +17,7 @@ Test Tube
---------
"""
from argparse import Namespace
from typing import Optional, Dict, Any, Union
from typing import Any, Dict, Optional, Union
try:
from test_tube import Experiment
@ -26,9 +26,9 @@ except ImportError: # pragma: no-cover
Experiment = None
_TEST_TUBE_AVAILABLE = False
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.core.lightning import LightningModule
class TestTubeLogger(LightningLoggerBase):

View File

@ -18,7 +18,7 @@ Weights and Biases
"""
import os
from argparse import Namespace
from typing import Optional, List, Dict, Union, Any
from typing import Any, Dict, List, Optional, Union
import torch.nn as nn
@ -36,7 +36,7 @@ from pytorch_lightning.utilities import rank_zero_only
class WandbLogger(LightningLoggerBase):
"""
r"""
Log using `Weights and Biases <https://www.wandb.com/>`_. Install it with pip:
.. code-block:: bash
@ -51,11 +51,10 @@ class WandbLogger(LightningLoggerBase):
anonymous: Enables or explicitly disables anonymous logging.
version: Sets the version, mainly used to resume a previous run.
project: The name of the project to which this run will belong.
tags: Tags associated with this run.
log_model: Save checkpoints in wandb dir to upload on W&B servers.
experiment: WandB experiment object
entity: The team posting this run (default: your username or your default team)
group: A unique string shared by all runs in a given group
experiment: WandB experiment object.
\**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by
:func:`wandb.init` can be passed as keyword arguments in this logger.
Example:
>>> from pytorch_lightning.loggers import WandbLogger
@ -70,19 +69,19 @@ class WandbLogger(LightningLoggerBase):
"""
def __init__(self,
name: Optional[str] = None,
save_dir: Optional[str] = None,
offline: bool = False,
id: Optional[str] = None,
anonymous: bool = False,
version: Optional[str] = None,
project: Optional[str] = None,
tags: Optional[List[str]] = None,
log_model: bool = False,
experiment=None,
entity=None,
group: Optional[str] = None):
def __init__(
self,
name: Optional[str] = None,
save_dir: Optional[str] = None,
offline: bool = False,
id: Optional[str] = None,
anonymous: bool = False,
version: Optional[str] = None,
project: Optional[str] = None,
log_model: bool = False,
experiment=None,
**kwargs
):
if not _WANDB_AVAILABLE:
raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover
' install it with `pip install wandb`.')
@ -91,13 +90,11 @@ class WandbLogger(LightningLoggerBase):
self._save_dir = save_dir
self._anonymous = 'allow' if anonymous else None
self._id = version or id
self._tags = tags
self._project = project
self._experiment = experiment
self._offline = offline
self._entity = entity
self._log_model = log_model
self._group = group
self._kwargs = kwargs
def __getstate__(self):
state = self.__dict__.copy()
@ -126,8 +123,7 @@ class WandbLogger(LightningLoggerBase):
os.environ['WANDB_MODE'] = 'dryrun'
self._experiment = wandb.init(
name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous,
reinit=True, id=self._id, resume='allow', tags=self._tags, entity=self._entity,
group=self._group)
reinit=True, id=self._id, resume='allow', **self._kwargs)
# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
self._save_dir = self._experiment.dir

View File

@ -69,7 +69,7 @@ def test_comet_logger_experiment_name():
_ = logger.experiment
comet.assert_called_once_with(api_key=api_key, project_name=None, workspace=None)
comet.assert_called_once_with(api_key=api_key, project_name=None)
comet().set_name.assert_called_once_with(experiment_name)