52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
import pickle
|
|
|
|
import tests.base.utils as tutils
|
|
from pytorch_lightning import Trainer
|
|
from tests.base import LightningTestModel
|
|
|
|
|
|
def test_testtube_logger(tmpdir):
|
|
"""Verify that basic functionality of test tube logger works."""
|
|
tutils.reset_seed()
|
|
hparams = tutils.get_default_hparams()
|
|
model = LightningTestModel(hparams)
|
|
|
|
logger = tutils.get_default_testtube_logger(tmpdir, False)
|
|
|
|
assert logger.name == 'lightning_logs'
|
|
|
|
trainer_options = dict(
|
|
default_save_path=tmpdir,
|
|
max_epochs=1,
|
|
train_percent_check=0.05,
|
|
logger=logger
|
|
)
|
|
|
|
trainer = Trainer(**trainer_options)
|
|
result = trainer.fit(model)
|
|
|
|
assert result == 1, 'Training failed'
|
|
|
|
|
|
def test_testtube_pickle(tmpdir):
|
|
"""Verify that pickling a trainer containing a test tube logger works."""
|
|
tutils.reset_seed()
|
|
|
|
hparams = tutils.get_default_hparams()
|
|
|
|
logger = tutils.get_default_testtube_logger(tmpdir, False)
|
|
logger.log_hyperparams(hparams)
|
|
logger.save()
|
|
|
|
trainer_options = dict(
|
|
default_save_path=tmpdir,
|
|
max_epochs=1,
|
|
train_percent_check=0.05,
|
|
logger=logger
|
|
)
|
|
|
|
trainer = Trainer(**trainer_options)
|
|
pkl_bytes = pickle.dumps(trainer)
|
|
trainer2 = pickle.loads(pkl_bytes)
|
|
trainer2.logger.log_metrics({'acc': 1.0})
|