neptune online (#1499)

This commit is contained in:
Jirka Borovec 2020-04-15 17:14:29 +02:00 committed by GitHub
parent b3fe17ddeb
commit 8322f1b039
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 10 deletions

View File

@ -33,7 +33,7 @@ class NeptuneLogger(LightningLoggerBase):
api_key: Optional[str] = None,
project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True,
offline_mode: bool = True,
offline_mode: bool = False,
experiment_name: Optional[str] = None,
upload_source_files: Optional[List[str]] = None,
params: Optional[Dict[str, Any]] = None,
@ -140,7 +140,7 @@ class NeptuneLogger(LightningLoggerBase):
"namespace/project_name" for example "tom/minst-classification".
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
You need to create the project in https://neptune.ai first.
offline_mode: Optional default True. If offline_mode=True no logs will be send
offline_mode: Optional default False. If offline_mode=True no logs will be send
to neptune. Usually used for debug and test purposes.
close_after_fit: Optional default True. If close_after_fit=False the experiment
will not be closed after training and additional metrics,

View File

@ -10,6 +10,15 @@ from pytorch_lightning.loggers import (
from tests.base import LightningTestModel
def _get_logger_args(logger_class, save_dir):
logger_args = {}
if 'save_dir' in inspect.getfullargspec(logger_class).args:
logger_args.update(save_dir=str(save_dir))
if 'offline_mode' in inspect.getfullargspec(logger_class).args:
logger_args.update(offline_mode=True)
return logger_args
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
@ -40,10 +49,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
super().log_metrics(metrics, step)
self.history.append((step, metrics))
if 'save_dir' in inspect.getfullargspec(logger_class).args:
logger = StoreHistoryLogger(save_dir=str(tmpdir))
else:
logger = StoreHistoryLogger()
logger_args = _get_logger_args(logger_class, tmpdir)
logger = StoreHistoryLogger(**logger_args)
trainer = Trainer(
max_epochs=1,
@ -80,10 +87,8 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
if 'save_dir' in inspect.getfullargspec(logger_class).args:
logger = logger_class(save_dir=str(tmpdir))
else:
logger = logger_class()
logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)
trainer = Trainer(
max_epochs=1,