51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
import pickle
|
|
|
|
import tests.base.utils as tutils
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.loggers import TrainsLogger
|
|
from tests.base import EvalModelTemplate
|
|
|
|
|
|
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Fits an ``EvalModelTemplate`` for a single epoch with a ``TrainsLogger``
    attached to the ``Trainer`` and asserts that ``trainer.fit`` returns ``1``.

    Args:
        tmpdir: directory handed to the trainer as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture -- confirm).
    """
    model = EvalModelTemplate()

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # real TRAINS server, so the hosts below should never be contacted --
    # confirm against the TrainsLogger documentation.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        # Keep the run short: one epoch with train_percent_check=0.05.
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,
            logger=logger,
        )
        result = trainer.fit(model)
    finally:
        # Finalize even when fitting raises, so the logger is not left open
        # for whatever test runs next.
        logger.finalize()

    assert result == 1, "Training failed"
|
|
|
|
|
|
def test_trains_pickle(tmpdir):
    """Verify that pickling trainer with TRAINS logger works.

    Builds a ``Trainer`` with a ``TrainsLogger`` attached, round-trips the
    trainer through ``pickle.dumps`` / ``pickle.loads`` and then logs a metric
    through the restored trainer's logger. The test passes when none of these
    steps raises.

    Args:
        tmpdir: directory handed to the trainer as ``default_root_dir``
            (presumably pytest's ``tmpdir`` fixture -- confirm).
    """
    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # real TRAINS server, so the hosts below should never be contacted --
    # confirm against the TrainsLogger documentation.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            logger=logger,
        )
        # The payload is produced a line above, so unpickling it is safe here.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)
        try:
            # The restored logger must still accept metrics.
            trainer2.logger.log_metrics({"acc": 1.0})
        finally:
            trainer2.logger.finalize()
    finally:
        # Finalize the original logger even if pickling or logging failed.
        logger.finalize()
|