157 lines
4.1 KiB
Python
157 lines
4.1 KiB
Python
import os
|
|
import pickle
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
import tests.base.utils as tutils
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.loggers import CometLogger
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
from tests.base import LightningTestModel
|
|
|
|
|
|
def test_comet_logger(tmpdir, monkeypatch):
    """Verify that basic functionality of Comet.ml logger works.

    Runs a short training (``max_epochs=1``, ``train_percent_check=0.05``)
    of a ``LightningTestModel`` with a ``CometLogger`` configured with a local
    ``save_dir`` and no API key. Afterwards it logs one more metric through
    ``trainer.logger`` and checks that ``fit`` reported success.
    """
    # prevent comet logger from trying to print at exit, since
    # pytest's stdout/stderr redirection breaks it
    import atexit

    # The stub mirrors ``atexit.register(func, *args, **kwargs) -> func``:
    # - it tolerates registrations that pass extra arguments;
    # - it returns ``func``, so decorator-style registration keeps the name
    #   bound.
    # It never actually registers anything.
    monkeypatch.setattr(atexit, 'register', lambda func, *args, **kwargs: func)

    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    comet_dir = os.path.join(tmpdir, 'cometruns')

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name='general',
        workspace='dummy-test',
    )

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        train_percent_check=0.05,
        logger=logger
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # The logger must still accept metrics (here a tensor value) after fit.
    trainer.logger.log_metrics({'acc': torch.ones(1)})

    assert result == 1, 'Training failed'
|
|
|
|
|
|
def test_comet_logger_online():
    """Test comet online with mocks.

    The Comet classes referenced by ``pytorch_lightning.loggers.comet`` are
    patched with mocks, so no network access happens. The test then checks
    how ``CometLogger`` constructs them for several combinations of
    ``api_key``, ``save_dir``, ``experiment_key`` and ``rest_api_key``.
    """
    # Test api_key given: an online CometExperiment is built from exactly
    # the api_key / workspace / project_name handed to the logger.
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
        logger = CometLogger(
            api_key='key',
            workspace='dummy-test',
            project_name='general'
        )

        # Access the property so the experiment object gets built
        # (presumably it is created lazily -- the assertion below only
        # requires it to exist after this access).
        _ = logger.experiment

        comet.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general'
        )

    # Test both given: with an api_key present the online experiment is
    # used, and save_dir is not forwarded to CometExperiment.
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
        logger = CometLogger(
            save_dir='test',
            api_key='key',
            workspace='dummy-test',
            project_name='general'
        )

        _ = logger.experiment

        comet.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general'
        )

    # Test neither given: without api_key or save_dir the logger cannot be
    # configured and must refuse at construction time.
    with pytest.raises(MisconfigurationException):
        CometLogger(
            workspace='dummy-test',
            project_name='general'
        )

    # Test already exists: experiment_key resumes an existing experiment
    # (passed as previous_experiment), and experiment_name is applied
    # afterwards via set_name.
    with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
        logger = CometLogger(
            experiment_key='test',
            experiment_name='experiment',
            api_key='key',
            workspace='dummy-test',
            project_name='general'
        )

        _ = logger.experiment

        comet_existing.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
            previous_experiment='test'
        )

        # Use return_value instead of calling the mock again: an extra call
        # would be recorded as a second construction of the experiment.
        comet_existing.return_value.set_name.assert_called_once_with('experiment')

    # Test rest_api_key given: constructing the logger alone creates comet's
    # REST API client with that key as its positional argument.
    with patch('pytorch_lightning.loggers.comet.API') as api:
        CometLogger(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
            rest_api_key='rest'
        )

        api.assert_called_once_with('rest')
|
|
|
|
|
|
def test_comet_pickle(tmpdir, monkeypatch):
    """Verify that pickling trainer with comet logger works.

    Round-trips a ``Trainer`` that holds a ``CometLogger`` (local
    ``save_dir``, no API key) through ``pickle``. It then checks that the
    restored trainer's logger still accepts metrics. No training is run.
    """
    # prevent comet logger from trying to print at exit, since
    # pytest's stdout/stderr redirection breaks it
    import atexit

    # The stub mirrors ``atexit.register(func, *args, **kwargs) -> func``:
    # - it tolerates registrations that pass extra arguments;
    # - it returns ``func``, so decorator-style registration keeps the name
    #   bound.
    # It never actually registers anything.
    monkeypatch.setattr(atexit, 'register', lambda func, *args, **kwargs: func)

    tutils.reset_seed()

    comet_dir = os.path.join(tmpdir, 'cometruns')

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name='general',
        workspace='dummy-test',
    )

    trainer_options = dict(
        default_save_path=tmpdir,
        max_epochs=1,
        logger=logger
    )

    trainer = Trainer(**trainer_options)

    # Unpickling is safe here: the bytes come from pickle.dumps just above,
    # never from external input.
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)

    trainer2.logger.log_metrics({'acc': 1.0})
|