Adapt `NeptuneLogger` to new `neptune-client` api (#6867)

* Initial split to NeptuneLegacyLogger and NeptuneLogger

* Adapt NeptuneLogger to neptune-pytorch-lightning repo version

* Fix stylecheck and tests

* Fix style and PR suggestions

* Expect Run object in NeptuneLogger.init

* Model checkpoint support and restructured tests

* Reformat code - use " instead of '

* Fix logging INTEGRATION_VERSION_KEY

* Update CHANGELOG.md

* Fix stylecheck

* Remove NeptuneLegacyLogger

* updated neptune-related docstrings

* PR suggestions

* update CODEOWNERS file
* move import logic to imports.py
* minor neptune.py improvements

* formatting fixes and minor updates

* Fix generation of docs

* formatting fixes and minor updates

* fix

* PR fixes vol. 2

* define return type of _dict_paths method
* bump required version of `neptune-client`

* Enable log_* functions

* Update pytorch_lightning/loggers/neptune.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Revert "Enable log_* functions"

This reverts commit 050a436899b7f3582c0455dc27b171335b85a3a5.

* Make global helper lists internal

* Logger's `name` and `version` methods return proper results

* Initialize Run and its name and id at logger init level

* Make _init_run_instance static

* Add pre-commit hook code changes.

* Fix blacken-docs check

* Fix neptune doctests and test_all

* added docs comment about neptune-specific syntax

* added docs comment about neptune-specific syntax in the loggers.rst

* fix

* Add pickling test

* added myself to neptune codeowners

* Enable some of deprecated log_* functions

* Restore _run_instance for unpickled logger

* Add `step` parameter to log_* functions

* Fix stylecheck

* Fix checkstyle

* Fix checkstyle

* Update pytorch_lightning/loggers/neptune.py

Co-authored-by: thomas chaton <thomas@grid.ai>

* Fix tests

* Fix stylecheck

* fixed project name

* Fix windows tests

* Fix stylechecks

* Fix neptune docs tests

* docformatter fixes

* De-duplicate legacy_kwargs_msg

* Update .github/CODEOWNERS

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: kamil-kaczmarek <kamil.kaczmarek@neptune.ml>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Jakub Kuszneruk 2021-09-10 18:48:58 +02:00 committed by GitHub
parent ffd275f5b1
commit ee3787216a
8 changed files with 926 additions and 360 deletions

.gitignore

@@ -147,6 +147,7 @@ wandb
.forked/
*.prof
*.tar.gz
.neptune/
# dataset generated from bolts in examples.
cifar-10-batches-py

CHANGELOG.md

@@ -124,6 +124,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
The old [neptune-client](https://github.com/neptune-ai/neptune-client) API is supported by `NeptuneClient` from the [neptune-contrib](https://github.com/neptune-ai/neptune-contrib) repo.
- Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770))

docs/source/common/loggers.rst

@@ -9,7 +9,7 @@
Loggers
*******
Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...). TensorBoard is used by default,
Lightning supports the most popular logging frameworks (TensorBoard, Comet, Neptune, etc...). TensorBoard is used by default,
but you can pass to the :class:`~pytorch_lightning.trainer.trainer.Trainer` any combination of the following loggers.
.. note::
@@ -107,34 +107,50 @@ First, install the package:
pip install neptune-client
or with conda:
.. code-block:: bash
conda install -c conda-forge neptune-client
Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.trainer.Trainer`:
.. testcode::
.. code-block:: python
from pytorch_lightning.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS", # replace with your own
project_name="shared/pytorch-lightning-integration",
experiment_name="default", # Optional,
params={"max_epochs": 10}, # Optional,
tags=["pytorch-lightning", "mlp"], # Optional,
project="common/pytorch-lightning-integration", # format "<WORKSPACE/PROJECT>"
tags=["training", "resnet"], # optional
)
trainer = Trainer(logger=neptune_logger)
The :class:`~pytorch_lightning.loggers.NeptuneLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
.. testcode::
.. code-block:: python
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):
some_img = fake_image()
self.logger.experiment.add_image("generated_images", some_img, 0)
# generic recipe for logging custom metadata (neptune specific)
metadata = ...
self.logger.experiment["your/metadata/structure"].log(metadata)
Note that the syntax ``self.logger.experiment["your/metadata/structure"].log(metadata)``
is specific to Neptune and extends the logger's capabilities.
Specifically, it allows you to log various types of metadata, such as scores, files,
images, interactive visuals, and CSVs. Refer to the
`Neptune docs <https://docs.neptune.ai/you-should-know/logging-metadata#essential-logging-methods>`_
for more detailed explanations.
You can always use the regular logger methods ``log_metrics()`` and ``log_hyperparams()``, as these are also supported.
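For instance, a minimal sketch using the ``neptune_logger`` instance created above (the metric and parameter
names here are illustrative):

.. code-block:: python

    neptune_logger.log_hyperparams({"lr": 0.01, "batch_size": 32})
    neptune_logger.log_metrics({"train/acc": 0.95}, step=0)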
.. seealso::
:class:`~pytorch_lightning.loggers.NeptuneLogger` docs.
Logger `user guide <https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning>`_.
----------------
Tensorboard
@@ -227,7 +243,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
.. testcode::
.. code-block:: python
class MyModule(LightningModule):
def any_lightning_module_function_or_hook(self):

pytorch_lightning/loggers/neptune.py

@@ -15,25 +15,73 @@
Neptune Logger
--------------
"""
__all__ = [
"NeptuneLogger",
]
import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Dict, Iterable, Optional, Union
from functools import reduce
from typing import Any, Dict, Generator, Optional, Set, Union
from weakref import ReferenceType
import torch
from torch import is_tensor
from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import _module_available, rank_zero_only
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.model_summary import ModelSummary
if _NEPTUNE_AVAILABLE and _NEPTUNE_GREATER_EQUAL_0_9:
try:
from neptune import new as neptune
from neptune.new.exceptions import NeptuneLegacyProjectException, NeptuneOfflineModeFetchException
from neptune.new.run import Run
from neptune.new.types import File as NeptuneFile
except ImportError:
import neptune
from neptune.exceptions import NeptuneLegacyProjectException
from neptune.run import Run
from neptune.types import File as NeptuneFile
else:
# needed for test mocks, and function signatures
neptune, Run, NeptuneFile = None, None, None
log = logging.getLogger(__name__)
_NEPTUNE_AVAILABLE = _module_available("neptune")
if _NEPTUNE_AVAILABLE:
import neptune
from neptune.experiments import Experiment
else:
# needed for test mocks, these tests shall be updated
neptune, Experiment = None, None
_INTEGRATION_VERSION_KEY = "source_code/integrations/pytorch-lightning"
# kwargs used in previous NeptuneLogger version, now deprecated
_LEGACY_NEPTUNE_INIT_KWARGS = [
"project_name",
"offline_mode",
"experiment_name",
"experiment_id",
"params",
"properties",
"upload_source_files",
"abort_callback",
"logger",
"upload_stdout",
"upload_stderr",
"send_hardware_metrics",
"run_monitoring_thread",
"handle_uncaught_exceptions",
"git_info",
"hostname",
"notebook_id",
"notebook_path",
]
# kwargs used in legacy NeptuneLogger from neptune-pytorch-lightning package
_LEGACY_NEPTUNE_LOGGER_KWARGS = [
"base_namespace",
"close_after_fit",
]
class NeptuneLogger(LightningLoggerBase):
@@ -46,223 +94,394 @@ class NeptuneLogger(LightningLoggerBase):
pip install neptune-client
The Neptune logger can be used in the online mode or offline (silent) mode.
To log experiment data in online mode, :class:`NeptuneLogger` requires an API key.
In offline mode, the logger does not connect to Neptune.
or conda:
**ONLINE MODE**
.. code-block:: bash
conda install -c conda-forge neptune-client
**Quickstart**
Pass NeptuneLogger instance to the Trainer to log metadata with Neptune:
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS", # replace with your own
project="common/pytorch-lightning-integration", # format "<WORKSPACE/PROJECT>"
tags=["training", "resnet"], # optional
)
trainer = Trainer(max_epochs=10, logger=neptune_logger)
**How to use NeptuneLogger?**
Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:
.. code-block:: python
from neptune.new.types import File
from pytorch_lightning import LightningModule
class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
# log metrics
loss = ...
self.log("train/loss", loss)
def any_lightning_module_function_or_hook(self):
# log images
img = ...
self.logger.experiment["train/misclassified_images"].log(File.as_image(img))
# generic recipe
metadata = ...
self.logger.experiment["your/metadata/structure"].log(metadata)
Note that the syntax ``self.logger.experiment["your/metadata/structure"].log(metadata)`` is specific to Neptune
and extends the logger's capabilities. Specifically, it allows you to log various types of metadata,
such as scores, files, images, interactive visuals, and CSVs.
Refer to the `Neptune docs <https://docs.neptune.ai/you-should-know/logging-metadata#essential-logging-methods>`_
for more detailed explanations.
You can also use the regular logger methods ``log_metrics()`` and ``log_hyperparams()`` with NeptuneLogger,
as these are also supported.
**Log after fitting or testing is finished**
You can log objects after the fitting or testing methods are finished:
.. code-block:: python
neptune_logger = NeptuneLogger(project="common/pytorch-lightning-integration")
trainer = pl.Trainer(logger=neptune_logger)
model = ...
datamodule = ...
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
# Log objects after `fit` or `test` methods
# model summary
neptune_logger.log_model_summary(model=model, max_depth=-1)
# generic recipe
metadata = ...
neptune_logger.experiment["your/metadata/structure"].log(metadata)
**Log model checkpoints**
If you have :class:`~pytorch_lightning.callbacks.ModelCheckpoint` configured,
Neptune logger automatically logs model checkpoints.
Model weights will be uploaded to the: "model/checkpoints" namespace in the Neptune Run.
You can disable this option:
.. code-block:: python
neptune_logger = NeptuneLogger(project="common/pytorch-lightning-integration", log_model_checkpoints=False)
**Pass additional parameters to the Neptune run**
You can also pass ``neptune_run_kwargs`` to specify the run in greater detail, like ``tags`` or ``description``:
.. testcode::
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import NeptuneLogger
# arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class
# We are using an api_key for the anonymous user "neptuner" but you can use your own.
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS",
project_name="shared/pytorch-lightning-integration",
experiment_name="default", # Optional,
params={"max_epochs": 10}, # Optional,
tags=["pytorch-lightning", "mlp"], # Optional,
project="common/pytorch-lightning-integration",
name="lightning-run",
description="mlp quick run with pytorch-lightning",
tags=["mlp", "quick-run"],
)
trainer = Trainer(max_epochs=10, logger=neptune_logger)
trainer = Trainer(max_epochs=3, logger=neptune_logger)
**OFFLINE MODE**
Check `run documentation <https://docs.neptune.ai/essentials/api-reference/run>`_
for more info about additional run parameters.
.. testcode::
**Details about Neptune run structure**
from pytorch_lightning.loggers import NeptuneLogger
Runs can be viewed as nested dictionary-like structures that you can define in your code.
Thanks to this you can easily organize your metadata in a way that is most convenient for you.
# arguments made to NeptuneLogger are passed on to the neptune.experiments.Experiment class
neptune_logger = NeptuneLogger(
offline_mode=True,
project_name="USER_NAME/PROJECT_NAME",
experiment_name="default", # Optional,
params={"max_epochs": 10}, # Optional,
tags=["pytorch-lightning", "mlp"], # Optional,
)
trainer = Trainer(max_epochs=10, logger=neptune_logger)
The hierarchical structure that you apply to your metadata will be reflected later in the UI.
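For instance, a minimal sketch of such nested logging (the key names below are illustrative):

.. code-block:: python

    # every "/" in the key adds a level of nesting in the run structure
    neptune_logger.experiment["params/lr"] = 0.01
    neptune_logger.experiment["train/acc"].log(0.95)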
Use the logger anywhere in your :class:`~pytorch_lightning.core.lightning.LightningModule` as follows:
.. code-block:: python
class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
# log metrics
self.logger.experiment.log_metric("acc_train", ...)
# log images
self.logger.experiment.log_image("worse_predictions", ...)
# log model checkpoint
self.logger.experiment.log_artifact("model_checkpoint.pt", ...)
self.logger.experiment.whatever_neptune_supports(...)
def any_lightning_module_function_or_hook(self):
self.logger.experiment.log_metric("acc_train", ...)
self.logger.experiment.log_image("worse_predictions", ...)
self.logger.experiment.log_artifact("model_checkpoint.pt", ...)
self.logger.experiment.whatever_neptune_supports(...)
If you want to log objects after the training is finished use ``close_after_fit=False``:
.. code-block:: python
neptune_logger = NeptuneLogger(..., close_after_fit=False, ...)
trainer = Trainer(logger=neptune_logger)
trainer.fit()
# Log test metrics
trainer.test(model)
# Log additional metrics
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_true, y_pred)
neptune_logger.experiment.log_metric("test_accuracy", accuracy)
# Log charts
from scikitplot.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(16, 12))
plot_confusion_matrix(y_true, y_pred, ax=ax)
neptune_logger.experiment.log_image("confusion_matrix", fig)
# Save checkpoints folder
neptune_logger.experiment.log_artifact("my/checkpoints")
# When you are done, stop the experiment
neptune_logger.experiment.stop()
You can organize this way any type of metadata - images, parameters, metrics, model checkpoint, CSV files, etc.
See Also:
- An `Example experiment <https://ui.neptune.ai/o/shared/org/
pytorch-lightning-integration/e/PYTOR-66/charts>`_ showing the UI of Neptune.
- `Tutorial <https://docs.neptune.ai/integrations/pytorch_lightning.html>`_ on how to use
Pytorch Lightning with Neptune.
- Read about
`what objects you can log to Neptune <https://docs.neptune.ai/you-should-know/what-can-you-log-and-display>`_.
- Check `example run <https://app.neptune.ai/o/common/org/pytorch-lightning-integration/e/PTL-1/all>`_
with multiple types of metadata logged.
- For more detailed info check
`user guide <https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning>`_.
Args:
api_key: Required in online mode.
Neptune API token, found on https://neptune.ai.
Read how to get your
`API key <https://docs.neptune.ai/python-api/tutorials/get-started.html#copy-api-token>`_.
api_key: Optional.
Neptune API token, found on https://neptune.ai upon registration.
Read: `how to find and set Neptune API token <https://docs.neptune.ai/administration/security-and-privacy/
how-to-find-and-set-neptune-api-token>`_.
It is recommended to keep it in the `NEPTUNE_API_TOKEN`
environment variable and then you can leave ``api_key=None``.
project_name: Required in online mode. Qualified name of a project in a form of
"namespace/project_name" for example "tom/minst-classification".
environment variable and then you can drop ``api_key=None``.
project: Optional.
Name of a project in the form "my_workspace/my_project", for example "tom/mask-rcnn".
If ``None``, the value of the `NEPTUNE_PROJECT` environment variable will be used.
You need to create the project on https://neptune.ai first.
offline_mode: Optional default ``False``. If ``True`` no logs will be sent
to Neptune. Usually used for debug purposes.
close_after_fit: Optional default ``True``. If ``False`` the experiment
will not be closed after training and additional metrics,
images or artifacts can be logged. Also, remember to close the experiment explicitly
by running ``neptune_logger.experiment.stop()``.
experiment_name: Optional. Editable name of the experiment.
Name is displayed in the experiments Details (Metadata section) and
in experiments view as a column.
experiment_id: Optional. Default is ``None``. The ID of the existing experiment.
If specified, connect to experiment with experiment_id in project_name.
Input arguments "experiment_name", "params", "properties" and "tags" will be overriden based
on fetched experiment data.
prefix: A string to put at the beginning of metric keys.
\**kwargs: Additional arguments like `params`, `tags`, `properties`, etc. used by
:func:`neptune.Session.create_experiment` can be passed as keyword arguments in this logger.
name: Optional. Editable name of the run.
Run name appears in the "all metadata/sys" section in Neptune UI.
run: Optional. Default is ``None``. The Neptune ``Run`` object.
If specified, this ``Run`` will be used for logging instead of a new one being created.
When a run object is passed, you can't specify other Neptune init properties.
log_model_checkpoints: Optional. Default is ``True``. Log model checkpoints to Neptune.
Works only if ``ModelCheckpoint`` is passed to the ``Trainer``.
prefix: Optional. Default is ``"training"``. Root namespace for all metadata logging.
\**neptune_run_kwargs: Additional arguments like ``tags``, ``description``, ``capture_stdout``, etc.
used when run is created.
Raises:
ImportError:
If required Neptune package is not installed on the device.
If the required Neptune package (version >=0.9) is not installed on the device.
TypeError:
If configured project has not been migrated to new structure yet.
ValueError:
If argument passed to the logger's constructor is incorrect.
"""
LOGGER_JOIN_CHAR = "-"
LOGGER_JOIN_CHAR = "/"
PARAMETERS_KEY = "hyperparams"
ARTIFACTS_KEY = "artifacts"
def __init__(
self,
*, # force users to call `NeptuneLogger` initializer with `kwargs`
api_key: Optional[str] = None,
project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True,
offline_mode: bool = False,
experiment_name: Optional[str] = None,
experiment_id: Optional[str] = None,
prefix: str = "",
**kwargs,
project: Optional[str] = None,
name: Optional[str] = None,
run: Optional["Run"] = None,
log_model_checkpoints: Optional[bool] = True,
prefix: str = "training",
**neptune_run_kwargs,
):
if neptune is None:
raise ImportError(
"You want to use `neptune` logger which is not installed yet,"
" install it with `pip install neptune-client`."
)
super().__init__()
self.api_key = api_key
self.project_name = project_name
self.offline_mode = offline_mode
self.close_after_fit = close_after_fit
self.experiment_name = experiment_name
self._prefix = prefix
self._kwargs = kwargs
self.experiment_id = experiment_id
self._experiment = None
log.info(f'NeptuneLogger will work in {"offline" if self.offline_mode else "online"} mode')
# verify if user passed proper init arguments
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
super().__init__()
self._log_model_checkpoints = log_model_checkpoints
self._prefix = prefix
self._run_instance = self._init_run_instance(api_key, project, name, run, neptune_run_kwargs)
self._run_short_id = self.run._short_id # skipcq: PYL-W0212
try:
self.run.wait()
self._run_name = self._run_instance["sys/name"].fetch()
except NeptuneOfflineModeFetchException:
self._run_name = "offline-name"
def _init_run_instance(self, api_key, project, name, run, neptune_run_kwargs) -> Run:
if run is not None:
run_instance = run
else:
try:
run_instance = neptune.init(
project=project,
api_token=api_key,
name=name,
**neptune_run_kwargs,
)
except NeptuneLegacyProjectException as e:
raise TypeError(
f"""Project {project} has not been migrated to the new structure.
You can still integrate it with the Neptune logger using legacy Python API
available as part of neptune-contrib package:
- https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n
"""
) from e
# make sure that we log the integration version for both newly created and externally provided `Run` instances
run_instance[_INTEGRATION_VERSION_KEY] = __version__
# keep api_key and project, they will be required when resuming Run for pickled logger
self._api_key = api_key
self._project_name = run_instance._project_name # skipcq: PYL-W0212
return run_instance
def _construct_path_with_prefix(self, *keys) -> str:
"""Return sequence of keys joined by `LOGGER_JOIN_CHAR`, started with `_prefix` if defined."""
if self._prefix:
return self.LOGGER_JOIN_CHAR.join([self._prefix, *keys])
return self.LOGGER_JOIN_CHAR.join(keys)
@staticmethod
def _verify_input_arguments(
api_key: Optional[str],
project: Optional[str],
name: Optional[str],
run: Optional["Run"],
neptune_run_kwargs: dict,
):
legacy_kwargs_msg = (
"Following kwargs are deprecated: {legacy_kwargs}.\n"
"If you are looking for the Neptune logger using legacy Python API,"
" it's still available as part of neptune-contrib package:\n"
" - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
"The NeptuneLogger was re-written to use the neptune.new Python API\n"
" - https://neptune.ai/blog/neptune-new\n"
" - https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning\n"
"You should use arguments accepted by either NeptuneLogger.init() or neptune.init()"
)
# check if user used legacy kwargs expected in `NeptuneLegacyLogger`
used_legacy_kwargs = [
legacy_kwarg for legacy_kwarg in neptune_run_kwargs if legacy_kwarg in _LEGACY_NEPTUNE_INIT_KWARGS
]
if used_legacy_kwargs:
raise ValueError(legacy_kwargs_msg.format(legacy_kwargs=used_legacy_kwargs))
# check if user used legacy kwargs expected in `NeptuneLogger` from neptune-pytorch-lightning package
used_legacy_neptune_kwargs = [
legacy_kwarg for legacy_kwarg in neptune_run_kwargs if legacy_kwarg in _LEGACY_NEPTUNE_LOGGER_KWARGS
]
if used_legacy_neptune_kwargs:
raise ValueError(legacy_kwargs_msg.format(legacy_kwargs=used_legacy_neptune_kwargs))
# check if user passed new client `Run` object
if run is not None and not isinstance(run, Run):
raise ValueError(
"Run parameter expected to be of type `neptune.new.Run`.\n"
"If you are looking for the Neptune logger using legacy Python API,"
" it's still available as part of neptune-contrib package:\n"
" - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
"The NeptuneLogger was re-written to use the neptune.new Python API\n"
" - https://neptune.ai/blog/neptune-new\n"
" - https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning\n"
)
# check if user passed redundant neptune.init arguments when passed run
any_neptune_init_arg_passed = any(arg is not None for arg in [api_key, project, name]) or neptune_run_kwargs
if run is not None and any_neptune_init_arg_passed:
raise ValueError(
"When an already initialized run object is provided"
" you can't provide other neptune.init() parameters.\n"
)
def __getstate__(self):
state = self.__dict__.copy()
# Experiment cannot be pickled, and additionally its ID cannot be pickled in offline mode
state["_experiment"] = None
if self.offline_mode:
state["experiment_id"] = None
# Run instance can't be pickled
state["_run_instance"] = None
return state
def __setstate__(self, state):
self.__dict__ = state
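# resume the run the logger was connected to before pickling, identified by its stored short id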
self._run_instance = neptune.init(project=self._project_name, api_token=self._api_key, run=self._run_short_id)
@property
@rank_zero_experiment
def experiment(self) -> Experiment:
def experiment(self) -> Run:
r"""
Actual Neptune object. To use neptune features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Actual Neptune run object. Allows you to use neptune logging features in your
:class:`~pytorch_lightning.core.lightning.LightningModule`.
Example::
self.logger.experiment.some_neptune_function()
class LitModel(LightningModule):
def training_step(self, batch, batch_idx):
# log metrics
acc = ...
self.logger.experiment["train/acc"].log(acc)
# log images
img = ...
self.logger.experiment["train/misclassified_images"].log(File.as_image(img))
Note that the syntax ``self.logger.experiment["your/metadata/structure"].log(metadata)``
is specific to Neptune and extends the logger's capabilities.
Specifically, it allows you to log various types of metadata, such as scores, files,
images, interactive visuals, and CSVs. Refer to the
`Neptune docs <https://docs.neptune.ai/you-should-know/logging-metadata#essential-logging-methods>`_
for more detailed explanations.
You can also use the regular logger methods ``log_metrics()`` and ``log_hyperparams()``
with NeptuneLogger, as these are also supported.
"""
return self.run
# Note that even though we initialize self._experiment in __init__,
# it may still end up being None after being pickled and un-pickled
if self._experiment is None:
self._experiment = self._create_or_get_experiment()
return self._experiment
@property
def run(self) -> Run:
return self._run_instance
@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # skipcq: PYL-W0221
r"""
Log hyper-parameters to the run.
Hyperparams will be logged under the "<prefix>/hyperparams" namespace.
Note:
You can also log parameters by directly using the logger instance:
``neptune_logger.experiment["model/hyper-parameters"] = params_dict``.
In this way you can keep the hierarchical structure of the parameters.
Args:
params: `dict`.
Python dictionary structure with parameters.
Example::
from pytorch_lightning.loggers import NeptuneLogger
PARAMS = {
"batch_size": 64,
"lr": 0.07,
"decay_factor": 0.97
}
neptune_logger = NeptuneLogger(
api_key="ANONYMOUS",
project="common/pytorch-lightning-integration"
)
neptune_logger.log_hyperparams(PARAMS)
"""
params = self._convert_params(params)
params = self._flatten_dict(params)
for key, val in params.items():
self.experiment.set_property(f"param__{key}", val)
params = self._sanitize_callable_params(params)
parameters_key = self.PARAMETERS_KEY
parameters_key = self._construct_path_with_prefix(parameters_key)
self.run[parameters_key] = params
@rank_zero_only
def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
"""Log metrics (numeric values) in Neptune experiments.
"""Log metrics (numeric values) in Neptune runs.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded, currently ignored
metrics: Dictionary with metric names as keys and measured quantities as values.
step: Step number at which the metrics should be recorded, currently ignored.
"""
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
if rank_zero_only.rank != 0:
raise ValueError("run tried to log from global_rank != 0")
metrics = self._add_prefix(metrics)
for key, val in metrics.items():
# `step` is ignored because Neptune expects strictly increasing step values which
# Lighting does not always guarantee.
self.log_metric(key, val)
self.experiment[key].log(val)
@rank_zero_only
def finalize(self, status: str) -> None:
if status:
self.experiment[self._construct_path_with_prefix("status")] = status
super().finalize(status)
if self.close_after_fit:
self.experiment.stop()
@property
def save_dir(self) -> Optional[str]:
@@ -270,131 +489,153 @@ class NeptuneLogger(LightningLoggerBase):
locally.
Returns:
None
the root directory where experiment logs get saved
"""
# Neptune does not save any local files
return None
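# neptune-client keeps its local data (e.g. offline mode runs) in a ".neptune" directory under the cwd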
return os.path.join(os.getcwd(), ".neptune")
def log_model_summary(self, model, max_depth=-1):
model_str = str(ModelSummary(model=model, max_depth=max_depth))
self.experiment[self._construct_path_with_prefix("model/summary")] = neptune.types.File.from_content(
content=model_str, extension="txt"
)
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
Args:
checkpoint_callback: the model checkpoint callback instance
"""
if not self._log_model_checkpoints:
return
file_names = set()
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
# save last model
if checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.experiment[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.experiment[f"{checkpoints_namespace}/{model_name}"].upload(key)
# remove old models logged to experiment if they are not part of best k models at this point
if self.experiment.exists(checkpoints_namespace):
exp_structure = self.experiment.get_structure()
uploaded_model_names = self._get_full_model_names_from_exp_structure(exp_structure, checkpoints_namespace)
for file_to_drop in list(uploaded_model_names - file_names):
del self.experiment[f"{checkpoints_namespace}/{file_to_drop}"]
# log best model path and best model score
if checkpoint_callback.best_model_path:
self.experiment[
self._construct_path_with_prefix("model/best_model_path")
] = checkpoint_callback.best_model_path
if checkpoint_callback.best_model_score:
self.experiment[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)
@staticmethod
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
"""Returns model name which is string `modle_path` appended to `checkpoint_callback.dirpath`."""
expected_model_path = f"{checkpoint_callback.dirpath}/"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
return model_path[len(expected_model_path) :]
@classmethod
def _get_full_model_names_from_exp_structure(cls, exp_structure: dict, namespace: str) -> Set[str]:
"""Returns all paths to properties which were already logged in `namespace`"""
structure_keys = namespace.split(cls.LOGGER_JOIN_CHAR)
uploaded_models_dict = reduce(lambda d, k: d[k], [exp_structure, *structure_keys])
return set(cls._dict_paths(uploaded_models_dict))
@classmethod
def _dict_paths(cls, d: dict, path_in_build: Optional[str] = None) -> Generator:
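# illustrative example: {"a": {"b": 1, "c": {"d": 2}}} yields "a/b" and "a/c/d"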
for k, v in d.items():
path = f"{path_in_build}/{k}" if path_in_build is not None else k
if not isinstance(v, dict):
yield path
else:
yield from cls._dict_paths(v, path)
@property
def name(self) -> str:
"""Gets the name of the experiment.
Returns:
The name of the experiment if not in offline mode else "offline-name".
"""
if self.offline_mode:
return "offline-name"
return self.experiment.name
"""Return the experiment name or 'offline-name' when exp is run in offline mode."""
return self._run_name
@property
def version(self) -> str:
"""Gets the id of the experiment.
"""Return the experiment version.
Returns:
The id of the experiment if not in offline mode else "offline-id-1234".
It's the Neptune Run's short_id.
"""
if self.offline_mode:
return "offline-id-1234"
return self.experiment.id
return self._run_short_id
@staticmethod
def _signal_deprecated_api_usage(f_name, sample_code, raise_exception=False):
msg_suffix = (
f"If you are looking for the Neptune logger using legacy Python API,"
f" it's still available as part of neptune-contrib package:\n"
f" - https://docs-legacy.neptune.ai/integrations/pytorch_lightning.html\n"
f"The NeptuneLogger was re-written to use the neptune.new Python API\n"
f" - https://neptune.ai/blog/neptune-new\n"
f" - https://docs.neptune.ai/integrations-and-supported-tools/model-training/pytorch-lightning\n"
f"Instead of `logger.{f_name}` you can use:\n"
f"\t{sample_code}"
)
if not raise_exception:
warnings.warn(
"The function you've used is deprecated in v1.5.0 and will be removed in v1.7.0. " + msg_suffix
)
else:
raise ValueError("The function you've used is deprecated.\n" + msg_suffix)
@rank_zero_only
def log_metric(
self, metric_name: str, metric_value: Union[torch.Tensor, float, str], step: Optional[int] = None
) -> None:
"""Log metrics (numeric values) in Neptune experiments.
Args:
metric_name: The name of log, i.e. mse, loss, accuracy.
metric_value: The value of the log (data-point).
step: Step number at which the metrics should be recorded, must be strictly increasing
"""
if is_tensor(metric_value):
def log_metric(self, metric_name: str, metric_value: Union[torch.Tensor, float, str], step: Optional[int] = None):
key = f"{self._prefix}/{metric_name}"
self._signal_deprecated_api_usage("log_metric", f"logger.run['{key}'].log(42)")
if torch.is_tensor(metric_value):
metric_value = metric_value.cpu().detach()
if step is None:
self.experiment.log_metric(metric_name, metric_value)
else:
self.experiment.log_metric(metric_name, x=step, y=metric_value)
self.run[key].log(metric_value, step=step)
@rank_zero_only
def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None:
"""Log text data in Neptune experiments.
Args:
log_name: The name of log, i.e. mse, my_text_data, timing_info.
text: The value of the log (data-point).
step: Step number at which the metrics should be recorded, must be strictly increasing
"""
if step is None:
self.experiment.log_text(log_name, text)
else:
self.experiment.log_text(log_name, x=step, y=text)
key = f"{self._prefix}/{log_name}"
self._signal_deprecated_api_usage("log_text", f"logger.run['{key}].log('text')")
self.run[key].log(str(text), step=step)
@rank_zero_only
def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None:
"""Log image data in Neptune experiment.
Args:
log_name: The name of log, i.e. bboxes, visualisations, sample_images.
image: The value of the log (data-point).
Can be one of the following types: PIL image, `matplotlib.figure.Figure`,
path to image file (str)
step: Step number at which the metrics should be recorded, must be strictly increasing
"""
if step is None:
self.experiment.log_image(log_name, image)
else:
self.experiment.log_image(log_name, x=step, y=image)
key = f"{self._prefix}/{log_name}"
self._signal_deprecated_api_usage("log_image", f"logger.run['{key}'].log(File('path_to_image'))")
if isinstance(image, str):
# if `img` is path to file, convert it to file object
image = NeptuneFile(image)
self.run[key].log(image, step=step)
@rank_zero_only
def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None:
"""Save an artifact (file) in Neptune experiment storage.
Args:
artifact: A path to the file in local filesystem.
destination: Optional. Default is ``None``. A destination path.
If ``None`` is passed, an artifact file name will be used.
"""
self.experiment.log_artifact(artifact, destination)
key = f"{self._prefix}/{self.ARTIFACTS_KEY}/{artifact}"
self._signal_deprecated_api_usage("log_artifact", f"logger.run['{key}].log('path_to_file')")
self.run[key].log(destination)
@rank_zero_only
def set_property(self, key: str, value: Any) -> None:
"""Set key-value pair as Neptune experiment property.
Args:
key: Property key.
value: New value of a property.
"""
self.experiment.set_property(key, value)
def set_property(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"log_artifact", f"logger.run['{self._prefix}/{self.PARAMETERS_KEY}/key'].log(value)", raise_exception=True
)
@rank_zero_only
def append_tags(self, tags: Union[str, Iterable[str]]) -> None:
"""Appends tags to the neptune experiment.
Args:
tags: Tags to add to the current experiment. If str is passed, a single tag is added.
If multiple - comma separated - str are passed, all of them are added as tags.
If list of str is passed, all elements of the list are added as tags.
"""
if str(tags) == tags:
tags = [tags]  # make it an iterable if it is not yet
self.experiment.append_tags(*tags)
def _create_or_get_experiment(self):
if self.offline_mode:
project = neptune.Session(backend=neptune.OfflineBackend()).get_project("dry-run/project")
else:
session = neptune.Session.with_default_backend(api_token=self.api_key)
project = session.get_project(self.project_name)
if self.experiment_id is None:
exp = project.create_experiment(name=self.experiment_name, **self._kwargs)
self.experiment_id = exp.id
else:
exp = project.get_experiments(id=self.experiment_id)[0]
self.experiment_name = exp.get_system_properties()["name"]
self.params = exp.get_parameters()
self.properties = exp.get_properties()
self.tags = exp.get_tags()
return exp
def append_tags(self, *args, **kwargs):
self._signal_deprecated_api_usage(
"append_tags", "logger.run['sys/tags'].add(['foo', 'bar'])", raise_exception=True
)

pytorch_lightning/utilities/imports.py

@@ -81,6 +81,8 @@ _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
_JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")
_KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_NEPTUNE_AVAILABLE = _module_available("neptune")
_NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_POPTORCH_AVAILABLE = _module_available("poptorch")
_RICH_AVAILABLE = _module_available("rich")

requirements/loggers.txt

@@ -1,5 +1,5 @@
# all supported loggers
neptune-client>=0.4.109
neptune-client>=0.10.0
comet-ml>=3.1.12
mlflow>=1.0.0
test_tube>=0.7.5

tests/loggers/test_all.py

@@ -36,6 +36,7 @@ from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
from tests.loggers.test_comet import _patch_comet_atexit
from tests.loggers.test_mlflow import mock_mlflow_run_creation
from tests.loggers.test_neptune import create_neptune_mock
def _get_logger_args(logger_class, save_dir):
@@ -72,7 +73,7 @@ def test_loggers_fit_test_all(tmpdir, monkeypatch):
):
_test_loggers_fit_test(tmpdir, MLFlowLogger)
with mock.patch("pytorch_lightning.loggers.neptune.neptune"):
with mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock):
_test_loggers_fit_test(tmpdir, NeptuneLogger)
with mock.patch("pytorch_lightning.loggers.test_tube.Experiment"):
@@ -233,10 +234,10 @@ def _test_loggers_save_dir_and_weights_save_path(tmpdir, logger_class):
CometLogger,
CSVLogger,
MLFlowLogger,
NeptuneLogger,
TensorBoardLogger,
TestTubeLogger,
# The WandbLogger gets tested for pickling in its own test.
# The NeptuneLogger gets tested for pickling in its own test.
],
)
def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class):
@@ -316,9 +317,7 @@ class RankZeroLoggerCheck(Callback):
assert pl_module.logger.experiment.something(foo="bar") is None
@pytest.mark.parametrize(
"logger_class", [CometLogger, CSVLogger, MLFlowLogger, NeptuneLogger, TensorBoardLogger, TestTubeLogger]
)
@pytest.mark.parametrize("logger_class", [CometLogger, CSVLogger, MLFlowLogger, TensorBoardLogger, TestTubeLogger])
@RunIf(skip_windows=True)
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
"""Test that loggers get replaced by dummy loggers on global rank > 0."""
@@ -369,9 +368,12 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
# Neptune
with mock.patch("pytorch_lightning.loggers.neptune.neptune"):
logger = _instantiate_logger(NeptuneLogger, save_dir=tmpdir, prefix=prefix)
logger = _instantiate_logger(NeptuneLogger, api_key="test", project="project", save_dir=tmpdir, prefix=prefix)
assert logger.experiment.__getitem__.call_count == 1
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_metric.assert_called_once_with("tmp-test", 1.0)
assert logger.experiment.__getitem__.call_count == 2
logger.experiment.__getitem__.assert_called_with("tmp/test")
logger.experiment.__getitem__().log.assert_called_once_with(1.0)
# TensorBoard
with mock.patch("pytorch_lightning.loggers.tensorboard.SummaryWriter"):

tests/loggers/test_neptune.py

@@ -11,111 +11,410 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import os
import pickle
import unittest
from collections import namedtuple
from unittest.mock import call, MagicMock, patch
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning import __version__, Trainer
from pytorch_lightning.loggers import NeptuneLogger
from tests.helpers import BoringModel
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_neptune_online(neptune):
logger = NeptuneLogger(api_key="test", project_name="project")
def create_neptune_mock():
"""Mock with provides nice `logger.name` and `logger.version` values.
created_experiment = neptune.Session.with_default_backend().get_project().create_experiment()
# It's important to check if the internal variable _experiment was initialized in __init__.
# Calling logger.experiment would cause a side-effect of initializing _experiment,
# if it wasn't already initialized.
assert logger._experiment is None
_ = logger.experiment
assert logger._experiment == created_experiment
assert logger.name == created_experiment.name
assert logger.version == created_experiment.id
Mostly due to the fact that Windows tests were failing with MagicMock-based strings, which were used to create
local directories in the filesystem.
"""
return MagicMock(
init=MagicMock(
return_value=MagicMock(
__getitem__=MagicMock(return_value=MagicMock(fetch=MagicMock(return_value="Run test name"))),
_short_id="TEST-1",
)
)
)
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_neptune_existing_experiment(neptune):
logger = NeptuneLogger(experiment_id="TEST-123")
neptune.Session.with_default_backend().get_project().get_experiments.assert_not_called()
experiment = logger.experiment
neptune.Session.with_default_backend().get_project().get_experiments.assert_called_once_with(id="TEST-123")
assert logger.experiment_name == experiment.get_system_properties()["name"]
assert logger.params == experiment.get_parameters()
assert logger.properties == experiment.get_properties()
assert logger.tags == experiment.get_tags()
class Run:
_short_id = "TEST-42"
_project_name = "test-project"
def __setitem__(self, key, value):
# called once
assert key == "source_code/integrations/pytorch-lightning"
assert value == __version__
def wait(self):
# for test purposes
pass
def __getitem__(self, item):
# called once
assert item == "sys/name"
return MagicMock(fetch=MagicMock(return_value="Test name"))
def __getstate__(self):
raise pickle.PicklingError("Runs are unpickleable")
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_neptune_offline(neptune):
logger = NeptuneLogger(offline_mode=True)
neptune.Session.assert_not_called()
_ = logger.experiment
neptune.Session.assert_called_once_with(backend=neptune.OfflineBackend())
assert logger.experiment == neptune.Session().get_project().create_experiment()
@pytest.fixture
def tmpdir_unittest_fixture(request, tmpdir):
"""Proxy for pytest `tmpdir` fixture between pytest and unittest.
Resources:
* https://docs.pytest.org/en/6.2.x/tmpdir.html#the-tmpdir-fixture
* https://towardsdatascience.com/mixing-pytest-fixture-and-unittest-testcase-for-selenium-test-9162218e8c8e
"""
request.cls.tmpdir = tmpdir
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_neptune_additional_methods(neptune):
logger = NeptuneLogger(api_key="test", project_name="project")
@patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock)
class TestNeptuneLogger(unittest.TestCase):
def test_neptune_online(self, neptune):
logger = NeptuneLogger(api_key="test", project="project")
created_run_mock = logger._run_instance
created_experiment = neptune.Session.with_default_backend().get_project().create_experiment()
self.assertEqual(logger._run_instance, created_run_mock)
self.assertEqual(logger.name, "Run test name")
self.assertEqual(logger.version, "TEST-1")
self.assertEqual(neptune.init.call_count, 1)
self.assertEqual(created_run_mock.__getitem__.call_count, 1)
self.assertEqual(created_run_mock.__setitem__.call_count, 1)
created_run_mock.__getitem__.assert_called_once_with(
"sys/name",
)
created_run_mock.__setitem__.assert_called_once_with("source_code/integrations/pytorch-lightning", __version__)
logger.log_metric("test", torch.ones(1))
created_experiment.log_metric.assert_called_once_with("test", torch.ones(1))
created_experiment.log_metric.reset_mock()
@patch("pytorch_lightning.loggers.neptune.Run", Run)
def test_online_with_custom_run(self, neptune):
created_run = Run()
logger = NeptuneLogger(run=created_run)
logger.log_metric("test", 1.0)
created_experiment.log_metric.assert_called_once_with("test", 1.0)
created_experiment.log_metric.reset_mock()
assert logger._run_instance == created_run
self.assertEqual(logger._run_instance, created_run)
self.assertEqual(logger.version, created_run._short_id)
self.assertEqual(neptune.init.call_count, 0)
logger.log_metric("test", 1.0, step=2)
created_experiment.log_metric.assert_called_once_with("test", x=2, y=1.0)
created_experiment.log_metric.reset_mock()
@patch("pytorch_lightning.loggers.neptune.Run", Run)
def test_neptune_pickling(self, neptune):
unpickleable_run = Run()
logger = NeptuneLogger(run=unpickleable_run)
self.assertEqual(0, neptune.init.call_count)
logger.log_text("test", "text")
created_experiment.log_text.assert_called_once_with("test", "text")
created_experiment.log_text.reset_mock()
pickled_logger = pickle.dumps(logger)
unpickled = pickle.loads(pickled_logger)
logger.log_image("test", "image file")
created_experiment.log_image.assert_called_once_with("test", "image file")
created_experiment.log_image.reset_mock()
neptune.init.assert_called_once_with(project="test-project", api_token=None, run="TEST-42")
self.assertIsNotNone(unpickled.experiment)
logger.log_image("test", "image file", step=2)
created_experiment.log_image.assert_called_once_with("test", x=2, y="image file")
created_experiment.log_image.reset_mock()
@patch("pytorch_lightning.loggers.neptune.Run", Run)
def test_online_with_wrong_kwargs(self, neptune):
"""Tests combinations of kwargs together with `run` kwarg which makes some of other parameters unavailable
in init."""
with self.assertRaises(ValueError):
NeptuneLogger(run="some string")
logger.log_artifact("file")
created_experiment.log_artifact.assert_called_once_with("file", None)
with self.assertRaises(ValueError):
NeptuneLogger(run=Run(), project="redundant project")
logger.set_property("property", 10)
created_experiment.set_property.assert_called_once_with("property", 10)
with self.assertRaises(ValueError):
NeptuneLogger(run=Run(), api_key="redundant api key")
logger.append_tags("one tag")
created_experiment.append_tags.assert_called_once_with("one tag")
created_experiment.append_tags.reset_mock()
with self.assertRaises(ValueError):
NeptuneLogger(run=Run(), name="redundant api name")
logger.append_tags(["two", "tags"])
created_experiment.append_tags.assert_called_once_with("two", "tags")
with self.assertRaises(ValueError):
NeptuneLogger(run=Run(), foo="random **kwarg")
# this should work
NeptuneLogger(run=Run())
NeptuneLogger(project="foo")
NeptuneLogger(foo="bar")
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_neptune_leave_open_experiment_after_fit(neptune, tmpdir):
"""Verify that neptune experiment was closed after training."""
model = BoringModel()
@staticmethod
def _get_logger_with_mocks(**kwargs):
logger = NeptuneLogger(**kwargs)
run_instance_mock = MagicMock()
logger._run_instance = run_instance_mock
logger._run_instance.__getitem__.return_value.fetch.return_value = "exp-name"
run_attr_mock = MagicMock()
logger._run_instance.__getitem__.return_value = run_attr_mock
def _run_training(logger):
logger._experiment = MagicMock()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0.05, logger=logger)
assert trainer.log_dir is None
return logger, run_instance_mock, run_attr_mock
def test_neptune_additional_methods(self, neptune):
logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")
logger.experiment["key1"].log(torch.ones(1))
run_instance_mock.__getitem__.assert_called_once_with("key1")
run_instance_mock.__getitem__().log.assert_called_once_with(torch.ones(1))
def _fit_and_test(self, logger, model):
trainer = Trainer(default_root_dir=self.tmpdir, max_epochs=1, limit_train_batches=0.05, logger=logger)
assert trainer.log_dir == os.path.join(os.getcwd(), ".neptune")
trainer.fit(model)
assert trainer.log_dir is None
return logger
trainer.test(model)
assert trainer.log_dir == os.path.join(os.getcwd(), ".neptune")
logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
assert logger_close_after_fit._experiment.stop.call_count == 1
@pytest.mark.usefixtures("tmpdir_unittest_fixture")
def test_neptune_leave_open_experiment_after_fit(self, neptune):
"""Verify that neptune experiment was NOT closed after training."""
# given
logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")
logger_open_after_fit = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False))
assert logger_open_after_fit._experiment.stop.call_count == 0
# when
self._fit_and_test(
logger=logger,
model=BoringModel(),
)
# then
assert run_instance_mock.stop.call_count == 0
@pytest.mark.usefixtures("tmpdir_unittest_fixture")
def test_neptune_log_metrics_on_trained_model(self, neptune):
"""Verify that trained models do log data."""
# given
class LoggingModel(BoringModel):
def validation_epoch_end(self, outputs):
self.log("some/key", 42)
# and
logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project")
# when
self._fit_and_test(
logger=logger,
model=LoggingModel(),
)
# then
run_instance_mock.__getitem__.assert_any_call("training/some/key")
run_instance_mock.__getitem__.return_value.log.assert_has_calls([call(42)])
def test_log_hyperparams(self, neptune):
params = {"foo": "bar", "nested_foo": {"bar": 42}}
test_variants = [
({}, "training/hyperparams"),
({"prefix": "custom_prefix"}, "custom_prefix/hyperparams"),
({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/hyperparams"),
]
for prefix, hyperparams_key in test_variants:
# given:
logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project", **prefix)
# when: log hyperparams
logger.log_hyperparams(params)
# then
self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
run_instance_mock.__setitem__.assert_called_once_with(hyperparams_key, params)
def test_log_metrics(self, neptune):
metrics = {
"foo": 42,
"bar": 555,
}
test_variants = [
({}, ("training/foo", "training/bar")),
({"prefix": "custom_prefix"}, ("custom_prefix/foo", "custom_prefix/bar")),
({"prefix": "custom/nested/prefix"}, ("custom/nested/prefix/foo", "custom/nested/prefix/bar")),
]
for prefix, (metrics_foo_key, metrics_bar_key) in test_variants:
# given:
logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
api_key="test", project="project", **prefix
)
# when: log metrics
logger.log_metrics(metrics)
# then:
self.assertEqual(run_instance_mock.__setitem__.call_count, 0)
self.assertEqual(run_instance_mock.__getitem__.call_count, 2)
run_instance_mock.__getitem__.assert_any_call(metrics_foo_key)
run_instance_mock.__getitem__.assert_any_call(metrics_bar_key)
run_attr_mock.log.assert_has_calls([call(42), call(555)])
def test_log_model_summary(self, neptune):
model = BoringModel()
test_variants = [
({}, "training/model/summary"),
({"prefix": "custom_prefix"}, "custom_prefix/model/summary"),
({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model/summary"),
]
for prefix, model_summary_key in test_variants:
# given:
logger, run_instance_mock, _ = self._get_logger_with_mocks(api_key="test", project="project", **prefix)
file_from_content_mock = neptune.types.File.from_content()
# when: log metrics
logger.log_model_summary(model)
# then:
self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
self.assertEqual(run_instance_mock.__getitem__.call_count, 0)
run_instance_mock.__setitem__.assert_called_once_with(model_summary_key, file_from_content_mock)
def test_after_save_checkpoint(self, neptune):
test_variants = [
({}, "training/model"),
({"prefix": "custom_prefix"}, "custom_prefix/model"),
({"prefix": "custom/nested/prefix"}, "custom/nested/prefix/model"),
]
for prefix, model_key_prefix in test_variants:
# given:
logger, run_instance_mock, run_attr_mock = self._get_logger_with_mocks(
api_key="test", project="project", **prefix
)
cb_mock = MagicMock(
dirpath="path/to/models",
last_model_path="path/to/models/last",
best_k_models={
"path/to/models/model1": None,
"path/to/models/model2/with/slashes": None,
},
best_model_path="path/to/models/best_model",
best_model_score=None,
)
# when: save checkpoint
logger.after_save_checkpoint(cb_mock)
# then:
self.assertEqual(run_instance_mock.__setitem__.call_count, 1)
self.assertEqual(run_instance_mock.__getitem__.call_count, 3)
self.assertEqual(run_attr_mock.upload.call_count, 3)
run_instance_mock.__setitem__.assert_called_once_with(
f"{model_key_prefix}/best_model_path", "path/to/models/best_model"
)
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/last")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1")
run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes")
run_attr_mock.upload.assert_has_calls(
[
call("path/to/models/last"),
call("path/to/models/model1"),
call("path/to/models/model2/with/slashes"),
]
)
def test_save_dir(self, neptune):
# given
logger = NeptuneLogger(api_key="test", project="project")
# expect
self.assertEqual(logger.save_dir, os.path.join(os.getcwd(), ".neptune"))
class TestNeptuneLoggerDeprecatedUsages(unittest.TestCase):
@staticmethod
def _assert_legacy_usage(callback, *args, **kwargs):
with pytest.raises(ValueError):
callback(*args, **kwargs)
def test_legacy_kwargs(self):
legacy_neptune_kwargs = [
# NeptuneLegacyLogger kwargs
"project_name",
"offline_mode",
"experiment_name",
"experiment_id",
"params",
"properties",
"upload_source_files",
"abort_callback",
"logger",
"upload_stdout",
"upload_stderr",
"send_hardware_metrics",
"run_monitoring_thread",
"handle_uncaught_exceptions",
"git_info",
"hostname",
"notebook_id",
"notebook_path",
# NeptuneLogger from neptune-pytorch-lightning package kwargs
"base_namespace",
"close_after_fit",
]
for legacy_kwarg in legacy_neptune_kwargs:
self._assert_legacy_usage(NeptuneLogger, **{legacy_kwarg: None})
@patch("pytorch_lightning.loggers.neptune.warnings")
@patch("pytorch_lightning.loggers.neptune.NeptuneFile")
@patch("pytorch_lightning.loggers.neptune.neptune")
def test_legacy_functions(self, neptune, neptune_file_mock, warnings_mock):
logger = NeptuneLogger(api_key="test", project="project")
# test deprecated functions which will be shut down in pytorch-lightning 1.7.0
attr_mock = logger._run_instance.__getitem__
attr_mock.reset_mock()
fake_image = {}
logger.log_metric("metric", 42)
logger.log_text("text", "some string")
logger.log_image("image_obj", fake_image)
logger.log_image("image_str", "img/path")
logger.log_artifact("artifact", "some/path")
assert attr_mock.call_count == 5
assert warnings_mock.warn.call_count == 5
attr_mock.assert_has_calls(
[
call("training/metric"),
call().log(42, step=None),
call("training/text"),
call().log("some string", step=None),
call("training/image_obj"),
call().log(fake_image, step=None),
call("training/image_str"),
call().log(neptune_file_mock(), step=None),
call("training/artifacts/artifact"),
call().log("some/path"),
]
)
# test exception-raising functions
self._assert_legacy_usage(logger.set_property)
self._assert_legacy_usage(logger.append_tags)
class TestNeptuneLoggerUtils(unittest.TestCase):
def test__get_full_model_name(self):
# given:
SimpleCheckpoint = namedtuple("SimpleCheckpoint", ["dirpath"])
test_input_data = [
("key.ext", "foo/bar/key.ext", SimpleCheckpoint(dirpath="foo/bar")),
("key/in/parts.ext", "foo/bar/key/in/parts.ext", SimpleCheckpoint(dirpath="foo/bar")),
]
# expect:
for expected_model_name, *key_and_path in test_input_data:
self.assertEqual(NeptuneLogger._get_full_model_name(*key_and_path), expected_model_name)
def test__get_full_model_names_from_exp_structure(self):
# given:
input_dict = {
"foo": {
"bar": {
"lvl1_1": {"lvl2": {"lvl3_1": "some non important value", "lvl3_2": "some non important value"}},
"lvl1_2": "some non important value",
},
"other_non_important": {"val100": 100},
},
"other_non_important": {"val42": 42},
}
expected_keys = {"lvl1_1/lvl2/lvl3_1", "lvl1_1/lvl2/lvl3_2", "lvl1_2"}
# expect:
self.assertEqual(NeptuneLogger._get_full_model_names_from_exp_structure(input_dict, "foo/bar"), expected_keys)