lightning/tests/loggers/test_comet.py

108 lines
3.2 KiB
Python
Raw Normal View History

import os
from unittest.mock import patch
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
def test_comet_logger_online():
    """Test ``CometLogger`` online-mode construction with Comet mocked out.

    Scenarios covered:

    * ``api_key`` only -> online mode; ``CometExperiment`` is created with the
      api key, workspace and project name.
    * ``api_key`` and ``save_dir`` -> still online; ``save_dir`` is not
      forwarded to ``CometExperiment``.
    * neither ``api_key`` nor ``save_dir`` -> ``MisconfigurationException``.
    * ``experiment_key`` given -> ``CometExistingExperiment`` is created with
      ``previous_experiment`` set to that key and is renamed to
      ``experiment_name``.
    * ``rest_api_key`` given -> the Comet REST ``API`` client is constructed
      with that key.
    """
    # Test api_key given
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
        logger = CometLogger(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )
        assert logger.mode == 'online'

        # Access ``experiment`` so the (mocked) experiment object is created
        # before we inspect how it was constructed.
        _ = logger.experiment

        comet.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )

    # Test both given: the api key takes precedence, so the logger stays online.
    with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
        logger = CometLogger(
            save_dir='test',
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )
        assert logger.mode == 'online'

        _ = logger.experiment

        comet.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )

    # Test neither given
    with pytest.raises(MisconfigurationException):
        CometLogger(
            workspace='dummy-test',
            project_name='general',
        )

    # Test already exists
    with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
        logger = CometLogger(
            experiment_key='test',
            experiment_name='experiment',
            api_key='key',
            workspace='dummy-test',
            project_name='general',
        )

        _ = logger.experiment

        comet_existing.assert_called_once_with(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
            previous_experiment='test',
        )

        # Inspect the instance the logger received via ``return_value``.
        # Calling ``comet_existing()`` here would record a second, spurious
        # constructor call on the mock.
        comet_existing.return_value.set_name.assert_called_once_with('experiment')

    # Test REST API client creation
    with patch('pytorch_lightning.loggers.comet.API') as api:
        CometLogger(
            api_key='key',
            workspace='dummy-test',
            project_name='general',
            rest_api_key='rest',
        )

        api.assert_called_once_with('rest')
def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
    """Check that an offline ``CometLogger`` writes its files where expected.

    ``save_dir`` must stay empty until the experiment is materialised. After
    that it holds exactly ``<experiment id>.zip``. A one-epoch fit must then put
    its single checkpoint under ``tmpdir / 'test' / <version> / 'checkpoints'``
    (``'test'`` here is also the ``project_name`` passed to the logger).
    """
    import atexit

    # The Comet logger would print at interpreter exit, which breaks under
    # pytest's stdout/stderr redirection, so turn ``atexit.register`` into a no-op.
    monkeypatch.setattr(atexit, 'register', lambda _: None)

    comet_logger = CometLogger(project_name='test', save_dir=tmpdir)

    # Built with only ``save_dir``: nothing on disk yet, and offline mode.
    assert not os.listdir(tmpdir)
    assert comet_logger.mode == 'offline'
    assert comet_logger.save_dir == tmpdir

    _ = comet_logger.experiment
    run_version = comet_logger.version

    files_on_disk = set(os.listdir(tmpdir))
    assert files_on_disk == {f'{comet_logger.experiment.id}.zip'}

    template_model = EvalModelTemplate()
    trainer = Trainer(
        default_root_dir=tmpdir,
        logger=comet_logger,
        max_epochs=1,
        limit_val_batches=3,
    )
    trainer.fit(template_model)

    ckpt_dir = trainer.checkpoint_callback.dirpath
    assert ckpt_dir == (tmpdir / 'test' / run_version / 'checkpoints')
    assert set(os.listdir(ckpt_dir)) == {'epoch=0.ckpt'}