neptune online (#1499)
This commit is contained in:
parent
b3fe17ddeb
commit
8322f1b039
|
@ -33,7 +33,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
api_key: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
close_after_fit: Optional[bool] = True,
|
||||
offline_mode: bool = True,
|
||||
offline_mode: bool = False,
|
||||
experiment_name: Optional[str] = None,
|
||||
upload_source_files: Optional[List[str]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
|
@ -140,7 +140,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
"namespace/project_name" for example "tom/minst-classification".
|
||||
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
|
||||
You need to create the project in https://neptune.ai first.
|
||||
offline_mode: Optional default True. If offline_mode=True no logs will be send
|
||||
offline_mode: Optional default False. If offline_mode=True no logs will be send
|
||||
to neptune. Usually used for debug and test purposes.
|
||||
close_after_fit: Optional default True. If close_after_fit=False the experiment
|
||||
will not be closed after training and additional metrics,
|
||||
|
|
|
@ -10,6 +10,15 @@ from pytorch_lightning.loggers import (
|
|||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def _get_logger_args(logger_class, save_dir):
|
||||
logger_args = {}
|
||||
if 'save_dir' in inspect.getfullargspec(logger_class).args:
|
||||
logger_args.update(save_dir=str(save_dir))
|
||||
if 'offline_mode' in inspect.getfullargspec(logger_class).args:
|
||||
logger_args.update(offline_mode=True)
|
||||
return logger_args
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logger_class", [
|
||||
TensorBoardLogger,
|
||||
CometLogger,
|
||||
|
@ -40,10 +49,8 @@ def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
|
|||
super().log_metrics(metrics, step)
|
||||
self.history.append((step, metrics))
|
||||
|
||||
if 'save_dir' in inspect.getfullargspec(logger_class).args:
|
||||
logger = StoreHistoryLogger(save_dir=str(tmpdir))
|
||||
else:
|
||||
logger = StoreHistoryLogger()
|
||||
logger_args = _get_logger_args(logger_class, tmpdir)
|
||||
logger = StoreHistoryLogger(**logger_args)
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
|
@ -80,10 +87,8 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
|||
import atexit
|
||||
monkeypatch.setattr(atexit, 'register', lambda _: None)
|
||||
|
||||
if 'save_dir' in inspect.getfullargspec(logger_class).args:
|
||||
logger = logger_class(save_dir=str(tmpdir))
|
||||
else:
|
||||
logger = logger_class()
|
||||
logger_args = _get_logger_args(logger_class, tmpdir)
|
||||
logger = logger_class(**logger_args)
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
|
|
Loading…
Reference in New Issue