From 07b857769a3c9694791c34a46b4b25c5edf18a57 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Sat, 19 Sep 2020 22:21:43 +0530 Subject: [PATCH] Allow kwargs in Wandb & Neptune + kwargs docstring (#3475) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow kwargs in WandbLogger * isort * kwargs docstring * typo * kwargs for other loggers * pep and isort * formatting * fix failing test Co-authored-by: Adrian Wälchli --- pytorch_lightning/loggers/base.py | 4 +- pytorch_lightning/loggers/comet.py | 27 +++++------ pytorch_lightning/loggers/csv_logs.py | 15 +++--- pytorch_lightning/loggers/mlflow.py | 16 +++---- pytorch_lightning/loggers/neptune.py | 59 ++++++------------------ pytorch_lightning/loggers/tensorboard.py | 7 +-- pytorch_lightning/loggers/test_tube.py | 4 +- pytorch_lightning/loggers/wandb.py | 44 ++++++++---------- tests/loggers/test_comet.py | 2 +- 9 files changed, 70 insertions(+), 108 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index aa43c5155d..8f72830027 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -18,13 +18,13 @@ import operator from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Union, Optional, Dict, Iterable, Any, Callable, List, Sequence, Mapping, Tuple, MutableMapping +from typing import Any, Callable, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union import numpy as np import torch -from pytorch_lightning.utilities import rank_zero_only from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_only class LightningLoggerBase(ABC): diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index f0f52daba1..37a91bd98c 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -18,15 +18,14 @@ Comet """ import os - from argparse import Namespace -from typing import Optional, Dict, Union, Any +from typing import Any, Dict, Optional, Union try: - from comet_ml import Experiment as CometExperiment - from comet_ml import ExistingExperiment as CometExistingExperiment - from comet_ml import OfflineExperiment as CometOfflineExperiment from comet_ml import BaseExperiment as CometBaseExperiment + from comet_ml import ExistingExperiment as CometExistingExperiment + from comet_ml import Experiment as CometExperiment + from comet_ml import OfflineExperiment as CometOfflineExperiment from comet_ml import generate_guid try: @@ -34,7 +33,7 @@ try: except ImportError: # pragma: no-cover # For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300 from comet_ml.papi import API # pragma: no-cover - from comet_ml.config import get_config, get_api_key + from comet_ml.config import get_api_key, get_config except ImportError: # pragma: no-cover CometExperiment = None CometExistingExperiment = None @@ -51,8 +50,8 @@ from torch import is_tensor from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.exceptions import MisconfigurationException class CometLogger(LightningLoggerBase): @@ -102,7 +101,6 @@ class CometLogger(LightningLoggerBase): if either exists. 
save_dir: Required in offline mode. The path for the directory to save local comet logs. If given, this also sets the directory for saving checkpoints. - workspace: Optional. Name of workspace for this user project_name: Optional. Send your experiment to a specific project. Otherwise will be sent to Uncategorized Experiments. If the project name does not already exist, Comet.ml will create a new project. @@ -114,21 +112,21 @@ class CometLogger(LightningLoggerBase): the experiment will be in online or offline mode. This is useful if you use save_dir to control the checkpoints directory and have a ~/.comet.config file but still want to run offline experiments. + \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + :class:`CometExperiment` can be passed as keyword arguments in this logger. """ def __init__( self, api_key: Optional[str] = None, save_dir: Optional[str] = None, - workspace: Optional[str] = None, project_name: Optional[str] = None, rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None, experiment_key: Optional[str] = None, offline: bool = False, - **kwargs, + **kwargs ): - if not _COMET_AVAILABLE: raise ImportError( "You want to use `comet_ml` logger which is not installed yet," @@ -157,7 +155,6 @@ class CometLogger(LightningLoggerBase): log.info(f"CometLogger will be initialized in {self.mode} mode") - self.workspace = workspace self._project_name = project_name self._experiment_key = experiment_key self._experiment_name = experiment_name @@ -197,13 +194,14 @@ class CometLogger(LightningLoggerBase): if self.mode == "online": if self._experiment_key is None: self._experiment = CometExperiment( - api_key=self.api_key, workspace=self.workspace, project_name=self._project_name, **self._kwargs + api_key=self.api_key, + project_name=self._project_name, + **self._kwargs, ) self._experiment_key = self._experiment.get_key() else: self._experiment = CometExistingExperiment( api_key=self.api_key, - workspace=self.workspace, project_name=self._project_name, previous_experiment=self._experiment_key, **self._kwargs, @@ -211,7 +209,6 @@ class CometLogger(LightningLoggerBase): else: self._experiment = CometOfflineExperiment( offline_directory=self.save_dir, - workspace=self.workspace, project_name=self._project_name, **self._kwargs, ) diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 96e64c0d88..c22f46eb03 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -23,14 +23,14 @@ import csv import io import os from argparse import Namespace -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union import torch from pytorch_lightning import _logger as log from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase -from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn class ExperimentWriter(object): @@ -116,11 +116,12 @@ class CSVLogger(LightningLoggerBase): directory for existing versions, then automatically assigns the next available version. 
""" - def __init__(self, - save_dir: str, - name: Optional[str] = "default", - version: Optional[Union[int, str]] = None): - + def __init__( + self, + save_dir: str, + name: Optional[str] = "default", + version: Optional[Union[int, str]] = None + ): super().__init__() self._save_dir = save_dir self._name = name or '' diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 83dd6f7481..5433ef9079 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -18,7 +18,7 @@ MLflow """ from argparse import Namespace from time import time -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union try: import mlflow @@ -34,7 +34,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only - LOCAL_FILE_URI_PREFIX = "file:" @@ -77,12 +76,13 @@ class MLFlowLogger(LightningLoggerBase): """ - def __init__(self, - experiment_name: str = 'default', - tracking_uri: Optional[str] = None, - tags: Optional[Dict[str, Any]] = None, - save_dir: Optional[str] = './mlruns'): - + def __init__( + self, + experiment_name: str = 'default', + tracking_uri: Optional[str] = None, + tags: Optional[Dict[str, Any]] = None, + save_dir: Optional[str] = './mlruns' + ): if not _MLFLOW_AVAILABLE: raise ImportError('You want to use `mlflow` logger which is not installed yet,' ' install it with `pip install mlflow`.') diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 8aa148db25..486527cd59 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -17,8 +17,7 @@ Neptune ------- """ from argparse import Namespace -from typing import Optional, List, Dict, Any, Union, Iterable - +from typing import Any, Dict, Iterable, List, Optional, Union try: import neptune @@ -159,41 +158,19 @@ class NeptuneLogger(LightningLoggerBase): experiment_name: Optional. Editable name of the experiment. Name is displayed in the experiment’s Details (Metadata section) and in experiments view as a column. - upload_source_files: Optional. List of source files to be uploaded. - Must be list of str or single str. Uploaded sources are displayed - in the experiment’s Source code tab. - If ``None`` is passed, the Python file from which the experiment was created will be uploaded. - Pass an empty list (``[]``) to upload no files. - Unix style pathname pattern expansion is supported. - For example, you can pass ``'\*.py'`` - to upload all python source files from the current directory. - For recursion lookup use ``'\**/\*.py'`` (for Python 3.5 and later). - For more information see :mod:`glob` library. - params: Optional. Parameters of the experiment. - After experiment creation params are read-only. - Parameters are displayed in the experiment’s Parameters section and - each key-value pair can be viewed in the experiments view as a column. - properties: Optional. Default is ``{}``. Properties of the experiment. - They are editable after the experiment is created. - Properties are displayed in the experiment’s Details section and - each key-value pair can be viewed in the experiments view as a column. - tags: Optional. Default is ``[]``. Must be list of str. Tags of the experiment. - They are editable after the experiment is created (see: ``append_tag()`` and ``remove_tag()``). 
- Tags are displayed in the experiment’s Details section and can be viewed - in the experiments view as a column. + \**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by + :func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger. """ - def __init__(self, - api_key: Optional[str] = None, - project_name: Optional[str] = None, - close_after_fit: Optional[bool] = True, - offline_mode: bool = False, - experiment_name: Optional[str] = None, - upload_source_files: Optional[List[str]] = None, - params: Optional[Dict[str, Any]] = None, - properties: Optional[Dict[str, Any]] = None, - tags: Optional[List[str]] = None, - **kwargs): + def __init__( + self, + api_key: Optional[str] = None, + project_name: Optional[str] = None, + close_after_fit: Optional[bool] = True, + offline_mode: bool = False, + experiment_name: Optional[str] = None, + **kwargs + ): if not _NEPTUNE_AVAILABLE: raise ImportError('You want to use `neptune` logger which is not installed yet,' ' install it with `pip install neptune-client`.') @@ -203,10 +180,6 @@ class NeptuneLogger(LightningLoggerBase): self.offline_mode = offline_mode self.close_after_fit = close_after_fit self.experiment_name = experiment_name - self.upload_source_files = upload_source_files - self.params = params - self.properties = properties - self.tags = tags self._kwargs = kwargs self._experiment_id = None self._experiment = self._create_or_get_experiment() @@ -391,13 +364,7 @@ class NeptuneLogger(LightningLoggerBase): project = session.get_project(self.project_name) if self._experiment_id is None: - exp = project.create_experiment( - name=self.experiment_name, - params=self.params, - properties=self.properties, - tags=self.tags, - upload_source_files=self.upload_source_files, - **self._kwargs) + exp = project.create_experiment(name=self.experiment_name, **self._kwargs) else: exp = project.get_experiments(id=self._experiment_id)[0] diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index d9c5125173..24ff137a51 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -19,7 +19,7 @@ TensorBoard import os from argparse import Namespace -from typing import Optional, Dict, Union, Any +from typing import Any, Dict, Optional, Union from warnings import warn import torch @@ -27,11 +27,11 @@ from pkg_resources import parse_version from torch.utils.tensorboard import SummaryWriter from pytorch_lightning import _logger as log +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.core.lightning import LightningModule try: from omegaconf import Container, OmegaConf @@ -67,7 +67,8 @@ class TensorBoardLogger(LightningLoggerBase): model. default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is called without a metric (otherwise calls to log_hyperparams without a metric are ignored). - \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor. + \**kwargs: Additional arguments like `comment`, `filename_suffix`, etc. used by + :class:`SummaryWriter` can be passed as keyword arguments in this logger. 
""" NAME_HPARAMS_FILE = 'hparams.yaml' diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index a1adacf654..bc5f168aeb 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -17,7 +17,7 @@ Test Tube --------- """ from argparse import Namespace -from typing import Optional, Dict, Any, Union +from typing import Any, Dict, Optional, Union try: from test_tube import Experiment @@ -26,9 +26,9 @@ except ImportError: # pragma: no-cover Experiment = None _TEST_TUBE_AVAILABLE = False +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn -from pytorch_lightning.core.lightning import LightningModule class TestTubeLogger(LightningLoggerBase): diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 10588da6de..51ceafe614 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -18,7 +18,7 @@ Weights and Biases """ import os from argparse import Namespace -from typing import Optional, List, Dict, Union, Any +from typing import Any, Dict, List, Optional, Union import torch.nn as nn @@ -36,7 +36,7 @@ from pytorch_lightning.utilities import rank_zero_only class WandbLogger(LightningLoggerBase): - """ + r""" Log using `Weights and Biases `_. Install it with pip: .. code-block:: bash @@ -51,11 +51,10 @@ class WandbLogger(LightningLoggerBase): anonymous: Enables or explicitly disables anonymous logging. version: Sets the version, mainly used to resume a previous run. project: The name of the project to which this run will belong. - tags: Tags associated with this run. log_model: Save checkpoints in wandb dir to upload on W&B servers. - experiment: WandB experiment object - entity: The team posting this run (default: your username or your default team) - group: A unique string shared by all runs in a given group + experiment: WandB experiment object. + \**kwargs: Additional arguments like `entity`, `group`, `tags`, etc. used by + :func:`wandb.init` can be passed as keyword arguments in this logger. 
Example: >>> from pytorch_lightning.loggers import WandbLogger @@ -70,19 +69,19 @@ class WandbLogger(LightningLoggerBase): """ - def __init__(self, - name: Optional[str] = None, - save_dir: Optional[str] = None, - offline: bool = False, - id: Optional[str] = None, - anonymous: bool = False, - version: Optional[str] = None, - project: Optional[str] = None, - tags: Optional[List[str]] = None, - log_model: bool = False, - experiment=None, - entity=None, - group: Optional[str] = None): + def __init__( + self, + name: Optional[str] = None, + save_dir: Optional[str] = None, + offline: bool = False, + id: Optional[str] = None, + anonymous: bool = False, + version: Optional[str] = None, + project: Optional[str] = None, + log_model: bool = False, + experiment=None, + **kwargs + ): if not _WANDB_AVAILABLE: raise ImportError('You want to use `wandb` logger which is not installed yet,' # pragma: no-cover ' install it with `pip install wandb`.') @@ -91,13 +90,11 @@ class WandbLogger(LightningLoggerBase): self._save_dir = save_dir self._anonymous = 'allow' if anonymous else None self._id = version or id - self._tags = tags self._project = project self._experiment = experiment self._offline = offline - self._entity = entity self._log_model = log_model - self._group = group + self._kwargs = kwargs def __getstate__(self): state = self.__dict__.copy() @@ -126,8 +123,7 @@ class WandbLogger(LightningLoggerBase): os.environ['WANDB_MODE'] = 'dryrun' self._experiment = wandb.init( name=self._name, dir=self._save_dir, project=self._project, anonymous=self._anonymous, - reinit=True, id=self._id, resume='allow', tags=self._tags, entity=self._entity, - group=self._group) + reinit=True, id=self._id, resume='allow', **self._kwargs) # save checkpoints in wandb dir to upload on W&B servers if self._log_model: self._save_dir = self._experiment.dir diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index debb231745..16e8d8551b 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -69,7 +69,7 @@ def test_comet_logger_experiment_name(): _ = logger.experiment - comet.assert_called_once_with(api_key=api_key, project_name=None, workspace=None) + comet.assert_called_once_with(api_key=api_key, project_name=None) comet().set_name.assert_called_once_with(experiment_name)
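
A minimal usage sketch of the kwargs pass-through introduced by this patch, assuming placeholder project, entity, and parameter values: per the updated docstrings above, `WandbLogger` now forwards extra keyword arguments to `wandb.init()` and `NeptuneLogger` forwards them to `project.create_experiment()`, so the formerly dedicated constructor parameters (`tags`, `entity`, `group`, `params`, `properties`, `upload_source_files`) are simply passed as keyword arguments instead:

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import NeptuneLogger, WandbLogger

    # `entity`, `group` and `tags` are no longer named parameters of WandbLogger;
    # they are collected in **kwargs and forwarded verbatim to wandb.init().
    wandb_logger = WandbLogger(
        name="baseline-run",
        project="my-project",     # placeholder project name
        entity="my-team",         # forwarded to wandb.init()
        group="ablation-a",       # forwarded to wandb.init()
        tags=["baseline"],        # forwarded to wandb.init()
    )

    # Likewise, `params`, `properties`, `tags` and `upload_source_files` are now
    # forwarded by NeptuneLogger to project.create_experiment().
    neptune_logger = NeptuneLogger(
        project_name="my-workspace/my-project",  # placeholder; requires Neptune credentials
        experiment_name="baseline-run",
        params={"learning_rate": 1e-3},          # forwarded to create_experiment()
        tags=["baseline"],                       # forwarded to create_experiment()
    )

    # Both loggers can be passed to the Trainer as before.
    trainer = Trainer(logger=[wandb_logger, neptune_logger])

Existing code that still passes these arguments by name keeps working, since they are accepted through `**kwargs` and handed to the underlying experiment objects unchanged.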