2020-03-14 17:02:14 +00:00
|
|
|
import pickle
|
|
|
|
|
2020-03-25 11:46:27 +00:00
|
|
|
import tests.base.utils as tutils
|
2020-03-14 17:02:14 +00:00
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from pytorch_lightning.loggers import TrainsLogger
|
2020-03-25 11:46:27 +00:00
|
|
|
from tests.base import LightningTestModel
|
2020-03-14 17:02:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_trains_logger(tmpdir):
    """Verify that basic functionality of TRAINS logger works.

    Trains ``LightningTestModel`` for one epoch on a small fraction
    (``train_percent_check=0.05``) of the training data with a
    ``TrainsLogger`` attached, and checks that ``Trainer.fit`` reports
    success.

    Args:
        tmpdir: pytest-provided temporary directory, used as the trainer's
            ``default_save_path``.
    """
    tutils.reset_seed()

    hparams = tutils.get_default_hparams()
    model = LightningTestModel(hparams)

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # live TRAINS server during tests -- confirm against TrainsLogger docs.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            train_percent_check=0.05,
            logger=logger,
        )
        result = trainer.fit(model)
    finally:
        # Finalize even if Trainer construction or fit() raises, so a failing
        # run does not leave the logger open for subsequent tests.
        logger.finalize()

    assert result == 1, "Training failed"
|
|
|
|
|
|
|
|
|
|
|
|
def test_trains_pickle(tmpdir):
    """Verify that pickling trainer with TRAINS logger works.

    Builds a ``Trainer`` with a ``TrainsLogger``, round-trips it through
    ``pickle.dumps``/``pickle.loads``, and checks that the restored
    trainer's logger is still a ``TrainsLogger`` that accepts metrics and
    can be finalized.

    Args:
        tmpdir: pytest-provided temporary directory, used as the trainer's
            ``default_save_path``.
    """
    tutils.reset_seed()

    # NOTE(review): bypass mode presumably keeps the logger from talking to a
    # live TRAINS server during tests -- confirm against TrainsLogger docs.
    TrainsLogger.set_bypass_mode(True)
    TrainsLogger.set_credentials(
        api_host='http://integration.trains.allegro.ai:8008',
        files_host='http://integration.trains.allegro.ai:8081',
        web_host='http://integration.trains.allegro.ai:8080',
    )
    logger = TrainsLogger(project_name="lightning_log", task_name="pytorch lightning test")

    try:
        trainer = Trainer(
            default_save_path=tmpdir,
            max_epochs=1,
            logger=logger,
        )

        # The bytes come from this process only, so unpickling them is safe.
        pkl_bytes = pickle.dumps(trainer)
        trainer2 = pickle.loads(pkl_bytes)

        # The round-tripped trainer must still carry a usable TRAINS logger.
        assert isinstance(trainer2.logger, TrainsLogger)
        trainer2.logger.log_metrics({"acc": 1.0})
        trainer2.logger.finalize()
    finally:
        # Always finalize the original logger, even if pickling fails.
        logger.finalize()
|