fix flushing loggers (#1459)

* flushing loggers

* flushing loggers

* flushing loggers

* flushing loggers

* changelog

* typo

* fix trains

* optimize imports

* add logger test all

* add logger test pickle

* flake8

* fix benchmark

* hanging loggers

* try

* del

* all

* cleaning
Jirka Borovec 2020-04-15 02:32:33 +02:00 committed by GitHub
parent c96c6a6b33
commit b3fe17ddeb
21 changed files with 209 additions and 334 deletions

.github/workflows/ci-testing.yml

@@ -28,7 +28,7 @@ jobs:
requires: 'minimal'
# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 30
timeout-minutes: 15
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}

CHANGELOG.md

@@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
-
- Fixed loggers to flush the last logged metrics before continuing, e.g. before `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
-
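Concretely, buffered metrics no longer wait for `close()`. A minimal sketch of the fixed flow, using only the public API shown in the diff below (the directory name is a placeholder):

# Sketch of the fixed behavior: agg_and_log_metrics() buffers metric dicts per
# step; before this PR the last buffered dict was only emitted on close(), so
# trainer.test() results could be lost. After the PR, save() flushes it too.
from pytorch_lightning.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir='logs')  # placeholder directory
logger.agg_and_log_metrics({'test_loss': 0.25}, step=1)  # buffered, not written yet
logger.save()  # now calls _finalize_agg_metrics(), which logs the buffered metrics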

pytorch_lightning/loggers/base.py

@@ -48,7 +48,7 @@ class LightningLoggerBase(ABC):
`LightningLoggerBase.agg_and_log_metrics` method.
"""
self._rank = 0
self._prev_step = -1
self._prev_step: int = -1
self._metrics_to_agg: List[Dict[str, float]] = []
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func
@@ -98,15 +98,15 @@ class LightningLoggerBase(ABC):
return step, None
# compute the metrics
agg_step, agg_mets = self._finalize_agg_metrics()
agg_step, agg_mets = self._reduce_agg_metrics()
# as new step received reset accumulator
self._metrics_to_agg = [metrics]
self._prev_step = step
return agg_step, agg_mets
def _finalize_agg_metrics(self):
"""Aggregate accumulated metrics. This shall be called in close."""
def _reduce_agg_metrics(self):
"""Aggregate accumulated metrics."""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
@@ -116,6 +116,14 @@ class LightningLoggerBase(ABC):
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
return self._prev_step, agg_mets
def _finalize_agg_metrics(self):
"""This shall be called before save/close."""
agg_step, metrics_to_log = self._reduce_agg_metrics()
self._metrics_to_agg = []
if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Aggregates and records metrics.
This method doesn't log the passed metrics instantaneously, but instead
@@ -219,7 +227,7 @@ class LightningLoggerBase(ABC):
def save(self) -> None:
"""Save log data."""
pass
self._finalize_agg_metrics()
def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.
@@ -227,14 +235,11 @@ class LightningLoggerBase(ABC):
Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
pass
self.save()
def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
agg_step, metrics_to_log = self._finalize_agg_metrics()
if metrics_to_log is not None:
self.log_metrics(metrics=metrics_to_log, step=agg_step)
self.save()
@property
def rank(self) -> int:
@@ -292,7 +297,6 @@ class LoggerCollection(LightningLoggerBase):
@LightningLoggerBase.rank.setter
def rank(self, value: int) -> None:
self._rank = value
for logger in self._logger_iterable:
logger.rank = value
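In practice, the diff above splits the old `_finalize_agg_metrics` in two: `_reduce_agg_metrics()` computes the aggregate, `_finalize_agg_metrics()` logs and clears it, and `save()`/`finalize()` are now routed through that path. A hedged sketch of a custom logger relying on the new hooks (the `PrintLogger` class is hypothetical):

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only

class PrintLogger(LightningLoggerBase):
    """Hypothetical minimal logger, only to illustrate the new flush path."""

    @property
    def experiment(self):
        return None

    @rank_zero_only
    def log_metrics(self, metrics, step):
        print(step, metrics)

    def log_hyperparams(self, params):
        pass

    @property
    def name(self):
        return 'print'

    @property
    def version(self):
        return '0'

    @rank_zero_only
    def save(self):
        # super().save() runs _finalize_agg_metrics(), which emits any buffered
        # metrics through log_metrics() before anything is persisted
        super().save()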

pytorch_lightning/loggers/comet.py

@@ -36,10 +36,15 @@ class CometLogger(LightningLoggerBase):
Log using `comet.ml <https://www.comet.ml>`_.
"""
def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None,
workspace: Optional[str] = None, project_name: Optional[str] = None,
rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None, **kwargs):
def __init__(self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
**kwargs):
r"""
Requires either an API Key (online mode) or a local directory path (offline mode)
@@ -118,6 +123,7 @@ class CometLogger(LightningLoggerBase):
self.name = experiment_name
except TypeError as e:
log.exception("Failed to set experiment name for comet.ml logger")
self._kwargs = kwargs
@property
def experiment(self) -> CometBaseExperiment:
@@ -197,7 +203,7 @@ class CometLogger(LightningLoggerBase):
@property
def name(self) -> str:
return self.experiment.project_name
return str(self.experiment.project_name)
@name.setter
def name(self, value: str) -> None:
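The reworked signature keeps the offline workflow unchanged; a hedged usage sketch mirroring the parameters exercised by the offline test removed further below (all names are placeholders):

from pytorch_lightning.loggers import CometLogger

# passing save_dir and no api_key selects offline mode with local saves
logger = CometLogger(
    save_dir='comet_logs',   # placeholder directory
    project_name='general',  # placeholder project
    workspace='dummy-test',  # placeholder workspace
)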

pytorch_lightning/loggers/mlflow.py

@@ -23,6 +23,7 @@ Use the logger anywhere in your LightningModule as follows:
self.logger.experiment.whatever_ml_flow_supports(...)
"""
import os
from argparse import Namespace
from time import time
from typing import Optional, Dict, Any, Union
@@ -39,10 +40,14 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
class MLFlowLogger(LightningLoggerBase):
def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
tags: Dict[str, Any] = None):
r"""
"""MLFLow logger"""
def __init__(self,
experiment_name: str = 'default',
tracking_uri: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = None):
r"""
Logs using MLFlow
Args:
@@ -51,6 +56,8 @@ class MLFlowLogger(LightningLoggerBase):
tags (dict): todo this param
"""
super().__init__()
if not tracking_uri and save_dir:
tracking_uri = f'file:{os.sep * 2}{save_dir}'
self._mlflow_client = MlflowClient(tracking_uri)
self.experiment_name = experiment_name
self._run_id = None
@@ -59,7 +66,6 @@
@property
def experiment(self) -> MlflowClient:
r"""
Actual mlflow object. To use mlflow features do the following.
Example::
@@ -102,11 +108,9 @@ class MLFlowLogger(LightningLoggerBase):
continue
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
def save(self):
pass
@rank_zero_only
def finalize(self, status: str = 'FINISHED') -> None:
super().finalize(status)
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)
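The new `save_dir` argument is shorthand for a local file-based tracking URI; a small sketch of the equivalence implied by the diff above (the directory name is a placeholder):

import os

from pytorch_lightning.loggers import MLFlowLogger

logger = MLFlowLogger(experiment_name='test', save_dir='mlruns')
# roughly equivalent to spelling the tracking URI out by hand:
logger2 = MLFlowLogger(experiment_name='test',
                       tracking_uri=f'file:{os.sep * 2}mlruns')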

pytorch_lightning/loggers/neptune.py

@@ -29,13 +29,18 @@ class NeptuneLogger(LightningLoggerBase):
To log experiment data in online mode, NeptuneLogger requires an API key:
"""
def __init__(self, api_key: Optional[str] = None, project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True, offline_mode: bool = False,
def __init__(self,
api_key: Optional[str] = None,
project_name: Optional[str] = None,
close_after_fit: Optional[bool] = True,
offline_mode: bool = True,
experiment_name: Optional[str] = None,
upload_source_files: Optional[List[str]] = None, params: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, **kwargs):
upload_source_files: Optional[List[str]] = None,
params: Optional[Dict[str, Any]] = None,
properties: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
**kwargs):
r"""
Initialize a neptune.ai logger.
.. note:: Requires either an API Key (online mode) or a local directory path (offline mode)
@@ -135,8 +140,8 @@ class NeptuneLogger(LightningLoggerBase):
"namespace/project_name" for example "tom/mnist-classification".
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
You need to create the project in https://neptune.ai first.
offline_mode: Optional default False. If offline_mode=True no logs will be sent
to neptune. Usually used for debug purposes.
offline_mode: Optional default True. If offline_mode=True no logs will be sent
to neptune. Usually used for debug and test purposes.
close_after_fit: Optional default True. If close_after_fit=False the experiment
will not be closed after training and additional metrics,
images or artifacts can be logged. Also, remember to close the experiment explicitly
@@ -243,6 +248,7 @@ class NeptuneLogger(LightningLoggerBase):
@rank_zero_only
def finalize(self, status: str) -> None:
super().finalize(status)
if self.close_after_fit:
self.experiment.stop()
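With `offline_mode` now defaulting to True, online logging becomes opt-in; a hedged sketch of both modes (API key and project name are placeholders):

from pytorch_lightning.loggers import NeptuneLogger

offline_logger = NeptuneLogger()  # offline by default after this change

online_logger = NeptuneLogger(
    api_key='YOUR_API_TOKEN',         # placeholder key
    offline_mode=False,               # must now be disabled explicitly
    project_name='workspace/project'  # placeholder "namespace/project_name"
)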

pytorch_lightning/loggers/tensorboard.py

@@ -14,7 +14,6 @@ from pytorch_lightning import _logger as log
class TensorBoardLogger(LightningLoggerBase):
r"""
Log to local file system in TensorBoard format
Implemented using :class:`torch.utils.tensorboard.SummaryWriter`. Logs are saved to
@@ -40,10 +39,11 @@ class TensorBoardLogger(LightningLoggerBase):
"""
NAME_CSV_TAGS = 'meta_tags.csv'
def __init__(
self, save_dir: str, name: Optional[str] = "default",
version: Optional[Union[int, str]] = None, **kwargs
):
def __init__(self,
save_dir: str,
name: Optional[str] = "default",
version: Optional[Union[int, str]] = None,
**kwargs):
super().__init__()
self.save_dir = save_dir
self._name = name
@@ -51,7 +51,7 @@ class TensorBoardLogger(LightningLoggerBase):
self._experiment = None
self.tags = {}
self.kwargs = kwargs
self._kwargs = kwargs
@property
def root_dir(self) -> str:
@@ -92,7 +92,7 @@ class TensorBoardLogger(LightningLoggerBase):
return self._experiment
os.makedirs(self.root_dir, exist_ok=True)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment
@rank_zero_only
@@ -127,6 +127,7 @@ class TensorBoardLogger(LightningLoggerBase):
@rank_zero_only
def save(self) -> None:
super().save()
try:
self.experiment.flush()
except AttributeError:

pytorch_lightning/loggers/test_tube.py

@@ -18,10 +18,13 @@ class TestTubeLogger(LightningLoggerBase):
__test__ = False
def __init__(
self, save_dir: str, name: str = "default", description: Optional[str] = None,
debug: bool = False, version: Optional[int] = None, create_git_tag: bool = False
):
def __init__(self,
save_dir: str,
name: str = "default",
description: Optional[str] = None,
debug: bool = False,
version: Optional[int] = None,
create_git_tag: bool = False):
r"""
Example
@@ -105,12 +108,14 @@ class TestTubeLogger(LightningLoggerBase):
@rank_zero_only
def save(self) -> None:
super().save()
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
self.experiment.save()
@rank_zero_only
def finalize(self, status: str) -> None:
super().finalize(status)
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
self.save()
@@ -118,6 +123,7 @@ class TestTubeLogger(LightningLoggerBase):
@rank_zero_only
def close(self) -> None:
super().save()
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
if not self.debug:

pytorch_lightning/loggers/trains.py

@@ -295,11 +295,9 @@ class TrainsLogger(LightningLoggerBase):
delete_after_upload=delete_after_upload
)
def save(self) -> None:
pass
@rank_zero_only
def finalize(self, status: str = None) -> None:
# super().finalize(status)
if self.bypass_mode() or not self._trains:
return

pytorch_lightning/loggers/wandb.py

@@ -46,11 +46,18 @@ class WandbLogger(LightningLoggerBase):
trainer = Trainer(logger=wandb_logger)
"""
def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
offline: bool = False, id: Optional[str] = None, anonymous: bool = False,
version: Optional[str] = None, project: Optional[str] = None,
tags: Optional[List[str]] = None, log_model: bool = False,
experiment=None, entity=None):
def __init__(self,
name: Optional[str] = None,
save_dir: Optional[str] = None,
offline: bool = False,
id: Optional[str] = None,
anonymous: bool = False,
version: Optional[str] = None,
project: Optional[str] = None,
tags: Optional[List[str]] = None,
log_model: bool = False,
experiment=None,
entity=None):
super().__init__()
self._name = name
self._save_dir = save_dir

pytorch_lightning/trainer/evaluation_loop.py

@@ -370,14 +370,13 @@ class TrainerEvaluationLoopMixin(ABC):
# run evaluation
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
eval_results)
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
# add metrics to prog bar
self.add_tqdm_metrics(prog_bar_metrics)
# log results of test
if test_mode and self.proc_rank == 0 and len(callback_metrics) > 0:
if test_mode and self.proc_rank == 0:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)

pytorch_lightning/trainer/trainer.py

@@ -293,8 +293,7 @@ class Trainer(
# benchmarking
self.benchmark = benchmark
if benchmark:
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.benchmark = self.benchmark
# Transfer params
self.num_nodes = num_nodes

tests/base/mixins.py

@@ -89,8 +89,8 @@ class LightValidationMixin(LightValidationStepMixin):
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
results = {'progress_bar': tqdm_dict, 'log': tqdm_dict}
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
return results
@@ -355,8 +355,8 @@ class LightTestMixin(LightTestStepMixin):
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
result = {'progress_bar': tqdm_dict}
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
return result

tests/loggers/test_all.py (new file, 95 lines)

@@ -0,0 +1,95 @@
import inspect
import pickle
import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import (
TensorBoardLogger, MLFlowLogger, NeptuneLogger, TestTubeLogger, CometLogger)
from tests.base import LightningTestModel
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
# TrainsLogger, # TODO: add this one
# WandbLogger, # TODO: add this one
])
def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
"""Verify that basic functionality of all loggers."""
tutils.reset_seed()
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
class StoreHistoryLogger(logger_class):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.history = []
def log_metrics(self, metrics, step):
super().log_metrics(metrics, step)
self.history.append((step, metrics))
if 'save_dir' in inspect.getfullargspec(logger_class).args:
logger = StoreHistoryLogger(save_dir=str(tmpdir))
else:
logger = StoreHistoryLogger()
trainer = Trainer(
max_epochs=1,
logger=logger,
train_percent_check=0.2,
val_percent_check=0.5,
fast_dev_run=True,
)
trainer.fit(model)
trainer.test()
log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
assert log_metric_names == [(0, ['val_acc', 'val_loss']),
(0, ['train_some_val']),
(1, ['test_acc', 'test_loss'])]
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
# TrainsLogger, # TODO: add this one
# WandbLogger, # TODO: add this one
])
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
"""Verify that pickling trainer with logger works."""
tutils.reset_seed()
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
if 'save_dir' in inspect.getfullargspec(logger_class).args:
logger = logger_class(save_dir=str(tmpdir))
else:
logger = logger_class()
trainer = Trainer(
max_epochs=1,
logger=logger
)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})

tests/loggers/test_base.py

@@ -1,5 +1,4 @@
import pickle
from collections import OrderedDict
from unittest.mock import MagicMock
import numpy as np
@@ -59,18 +58,6 @@ class CustomLogger(LightningLoggerBase):
return "1"
class StoreHistoryLogger(CustomLogger):
def __init__(self):
super().__init__()
self.history = {}
@rank_zero_only
def log_metrics(self, metrics, step):
if step not in self.history:
self.history[step] = {}
self.history[step].update(metrics)
def test_custom_logger(tmpdir):
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
@@ -175,6 +162,18 @@ def test_adding_step_key(tmpdir):
def test_with_accumulate_grad_batches():
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""
class StoreHistoryLogger(CustomLogger):
def __init__(self):
super().__init__()
self.history = {}
@rank_zero_only
def log_metrics(self, metrics, step):
if step not in self.history:
self.history[step] = {}
self.history[step].update(metrics)
logger = StoreHistoryLogger()
np.random.seed(42)

tests/loggers/test_comet.py

@@ -1,51 +1,9 @@
import os
import pickle
from unittest.mock import patch
import pytest
import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel
def test_comet_logger(tmpdir, monkeypatch):
"""Verify that basic functionality of Comet.ml logger works."""
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
tutils.reset_seed()
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
comet_dir = os.path.join(tmpdir, 'cometruns')
# We test CometLogger in offline mode with local saves
logger = CometLogger(
save_dir=comet_dir,
project_name='general',
workspace='dummy-test',
)
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
trainer.logger.log_metrics({'acc': torch.ones(1)})
assert result == 1, 'Training failed'
def test_comet_logger_online():
@@ -120,37 +78,3 @@ def test_comet_logger_online():
)
api.assert_called_once_with('rest')
def test_comet_pickle(tmpdir, monkeypatch):
"""Verify that pickling trainer with comet logger works."""
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
tutils.reset_seed()
# hparams = tutils.get_default_hparams()
# model = LightningTestModel(hparams)
comet_dir = os.path.join(tmpdir, 'cometruns')
# We test CometLogger in offline mode with local saves
logger = CometLogger(
save_dir=comet_dir,
project_name='general',
workspace='dummy-test',
)
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)
trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})

tests/loggers/test_mlflow.py

@@ -1,54 +1,9 @@
import os
import pickle
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
from tests.base import LightningTestModel
def test_mlflow_logger(tmpdir):
def test_mlflow_logger_exists(tmpdir):
"""Verify that basic functionality of mlflow logger works."""
tutils.reset_seed()
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
mlflow_dir = os.path.join(tmpdir, 'mlruns')
logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
logger = MLFlowLogger('test', save_dir=tmpdir)
# Test already exists
logger2 = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
_ = logger2.run_id
# Try logging string
logger.log_metrics({'acc': 'test'})
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, 'Training failed'
def test_mlflow_pickle(tmpdir):
"""Verify that pickling trainer with mlflow logger works."""
tutils.reset_seed()
mlflow_dir = os.path.join(tmpdir, 'mlruns')
logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)
trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})
logger2 = MLFlowLogger('test', save_dir=tmpdir)
assert logger.run_id != logger2.run_id

tests/loggers/test_neptune.py

@@ -1,4 +1,3 @@
import pickle
from unittest.mock import patch, MagicMock
import torch
@@ -9,29 +8,9 @@ from pytorch_lightning.loggers import NeptuneLogger
from tests.base import LightningTestModel
def test_neptune_logger(tmpdir):
"""Verify that basic functionality of neptune logger works."""
tutils.reset_seed()
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
logger = NeptuneLogger(offline_mode=True)
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, 'Training failed'
@patch('pytorch_lightning.loggers.neptune.neptune')
def test_neptune_online(neptune):
logger = NeptuneLogger(api_key='test', project_name='project')
logger = NeptuneLogger(api_key='test', offline_mode=False, project_name='project')
neptune.init.assert_called_once_with(api_token='test', project_qualified_name='project')
assert logger.name == neptune.create_experiment().name
@@ -80,24 +59,6 @@ def test_neptune_additional_methods(neptune):
neptune.create_experiment().append_tags.assert_called_once_with('two', 'tags')
def test_neptune_pickle(tmpdir):
"""Verify that pickling trainer with neptune logger works."""
tutils.reset_seed()
logger = NeptuneLogger(offline_mode=True)
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
logger=logger
)
trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})
def test_neptune_leave_open_experiment_after_fit(tmpdir):
"""Verify that neptune experiment was closed after training"""
tutils.reset_seed()
@@ -121,6 +82,5 @@ def test_neptune_leave_open_experiment_after_fit(tmpdir):
logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
assert logger_close_after_fit._experiment.stop.call_count == 1
logger_open_after_fit = _run_training(
NeptuneLogger(offline_mode=True, close_after_fit=False))
logger_open_after_fit = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False))
assert logger_open_after_fit._experiment.stop.call_count == 0

tests/loggers/test_tensorboard.py

@@ -1,43 +1,9 @@
import pickle
from argparse import Namespace
import pytest
import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import LightningTestModel
def test_tensorboard_logger(tmpdir):
"""Verify that basic functionality of Tensorboard logger works."""
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
logger = TensorBoardLogger(save_dir=tmpdir, name="tensorboard_logger_test")
trainer_options = dict(max_epochs=1, train_percent_check=0.01, logger=logger)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
print("result finished")
assert result == 1, "Training failed"
def test_tensorboard_pickle(tmpdir):
"""Verify that pickling trainer with Tensorboard logger works."""
logger = TensorBoardLogger(save_dir=tmpdir, name="tensorboard_pickle_test")
trainer_options = dict(max_epochs=1, logger=logger)
trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})
def test_tensorboard_automatic_versioning(tmpdir):
@@ -79,13 +45,10 @@ def test_tensorboard_named_version(tmpdir):
# in the "minimum requirements" test setup
def test_tensorboard_no_name(tmpdir):
@pytest.mark.parametrize("name", ['', None])
def test_tensorboard_no_name(tmpdir, name):
"""Verify that None or empty name works"""
logger = TensorBoardLogger(save_dir=tmpdir, name="")
assert logger.root_dir == tmpdir
logger = TensorBoardLogger(save_dir=tmpdir, name=None)
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
assert logger.root_dir == tmpdir

tests/loggers/test_test_tube.py (deleted)

@@ -1,51 +0,0 @@
import pickle
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from tests.base import LightningTestModel
def test_testtube_logger(tmpdir):
"""Verify that basic functionality of test tube logger works."""
tutils.reset_seed()
hparams = tutils.get_default_hparams()
model = LightningTestModel(hparams)
logger = tutils.get_default_testtube_logger(tmpdir, False)
assert logger.name == 'lightning_logs'
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
)
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1, 'Training failed'
def test_testtube_pickle(tmpdir):
"""Verify that pickling a trainer containing a test tube logger works."""
tutils.reset_seed()
hparams = tutils.get_default_hparams()
logger = tutils.get_default_testtube_logger(tmpdir, False)
logger.log_hyperparams(hparams)
logger.save()
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
train_percent_check=0.05,
logger=logger
)
trainer = Trainer(**trainer_options)
pkl_bytes = pickle.dumps(trainer)
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})

tests/loggers/test_wandb.py

@@ -2,8 +2,6 @@ import os
import pickle
from unittest.mock import patch
import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
@@ -37,7 +35,9 @@ def test_wandb_logger(wandb):
@patch('pytorch_lightning.loggers.wandb.wandb')
def test_wandb_pickle(wandb):
"""Verify that pickling trainer with wandb logger works.
Wandb doesn't work well with pytest so we have to mock it out here."""
Wandb doesn't work well with pytest so we have to mock it out here.
"""
tutils.reset_seed()
class Experiment: