fix flushing loggers (#1459)
* flushing loggers * flushing loggers * flushing loggers * flushing loggers * changelog * typo * fix trains * optimize imports * add logger test all * add logger test pickle * flake8 * fix benchmark * hanging loggers * try * del * all * cleaning
This commit is contained in:
parent
c96c6a6b33
commit
b3fe17ddeb
|
@ -28,7 +28,7 @@ jobs:
|
|||
requires: 'minimal'
|
||||
|
||||
# Timeout: https://stackoverflow.com/a/59076067/4521646
|
||||
timeout-minutes: 30
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
|
|
@ -34,7 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
-
|
||||
- Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
|
||||
|
||||
-
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ class LightningLoggerBase(ABC):
|
|||
`LightningLoggerBase.agg_and_log_metrics` method.
|
||||
"""
|
||||
self._rank = 0
|
||||
self._prev_step = -1
|
||||
self._prev_step: int = -1
|
||||
self._metrics_to_agg: List[Dict[str, float]] = []
|
||||
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
|
||||
self._agg_default_func = agg_default_func
|
||||
|
@ -98,15 +98,15 @@ class LightningLoggerBase(ABC):
|
|||
return step, None
|
||||
|
||||
# compute the metrics
|
||||
agg_step, agg_mets = self._finalize_agg_metrics()
|
||||
agg_step, agg_mets = self._reduce_agg_metrics()
|
||||
|
||||
# as new step received reset accumulator
|
||||
self._metrics_to_agg = [metrics]
|
||||
self._prev_step = step
|
||||
return agg_step, agg_mets
|
||||
|
||||
def _finalize_agg_metrics(self):
|
||||
"""Aggregate accumulated metrics. This shall be called in close."""
|
||||
def _reduce_agg_metrics(self):
|
||||
"""Aggregate accumulated metrics."""
|
||||
# compute the metrics
|
||||
if not self._metrics_to_agg:
|
||||
agg_mets = None
|
||||
|
@ -116,6 +116,14 @@ class LightningLoggerBase(ABC):
|
|||
agg_mets = merge_dicts(self._metrics_to_agg, self._agg_key_funcs, self._agg_default_func)
|
||||
return self._prev_step, agg_mets
|
||||
|
||||
def _finalize_agg_metrics(self):
|
||||
"""This shall be called before save/close."""
|
||||
agg_step, metrics_to_log = self._reduce_agg_metrics()
|
||||
self._metrics_to_agg = []
|
||||
|
||||
if metrics_to_log is not None:
|
||||
self.log_metrics(metrics=metrics_to_log, step=agg_step)
|
||||
|
||||
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
|
||||
"""Aggregates and records metrics.
|
||||
This method doesn't log the passed metrics instantaneously, but instead
|
||||
|
@ -219,7 +227,7 @@ class LightningLoggerBase(ABC):
|
|||
|
||||
def save(self) -> None:
|
||||
"""Save log data."""
|
||||
pass
|
||||
self._finalize_agg_metrics()
|
||||
|
||||
def finalize(self, status: str) -> None:
|
||||
"""Do any processing that is necessary to finalize an experiment.
|
||||
|
@ -227,14 +235,11 @@ class LightningLoggerBase(ABC):
|
|||
Args:
|
||||
status: Status that the experiment finished with (e.g. success, failed, aborted)
|
||||
"""
|
||||
pass
|
||||
self.save()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Do any cleanup that is necessary to close an experiment."""
|
||||
agg_step, metrics_to_log = self._finalize_agg_metrics()
|
||||
|
||||
if metrics_to_log is not None:
|
||||
self.log_metrics(metrics=metrics_to_log, step=agg_step)
|
||||
self.save()
|
||||
|
||||
@property
|
||||
def rank(self) -> int:
|
||||
|
@ -292,7 +297,6 @@ class LoggerCollection(LightningLoggerBase):
|
|||
|
||||
@LightningLoggerBase.rank.setter
|
||||
def rank(self, value: int) -> None:
|
||||
self._rank = value
|
||||
for logger in self._logger_iterable:
|
||||
logger.rank = value
|
||||
|
||||
|
|
|
@ -36,10 +36,15 @@ class CometLogger(LightningLoggerBase):
|
|||
Log using `comet.ml <https://www.comet.ml>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None,
|
||||
workspace: Optional[str] = None, project_name: Optional[str] = None,
|
||||
rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None,
|
||||
experiment_key: Optional[str] = None, **kwargs):
|
||||
def __init__(self,
|
||||
api_key: Optional[str] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
workspace: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
rest_api_key: Optional[str] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
experiment_key: Optional[str] = None,
|
||||
**kwargs):
|
||||
r"""
|
||||
|
||||
Requires either an API Key (online mode) or a local directory path (offline mode)
|
||||
|
@ -118,6 +123,7 @@ class CometLogger(LightningLoggerBase):
|
|||
self.name = experiment_name
|
||||
except TypeError as e:
|
||||
log.exception("Failed to set experiment name for comet.ml logger")
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def experiment(self) -> CometBaseExperiment:
|
||||
|
@ -197,7 +203,7 @@ class CometLogger(LightningLoggerBase):
|
|||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.experiment.project_name
|
||||
return str(self.experiment.project_name)
|
||||
|
||||
@name.setter
|
||||
def name(self, value: str) -> None:
|
||||
|
|
|
@ -23,6 +23,7 @@ Use the logger anywhere in you LightningModule as follows:
|
|||
self.logger.experiment.whatever_ml_flow_supports(...)
|
||||
|
||||
"""
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from time import time
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
@ -39,10 +40,14 @@ from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
|
|||
|
||||
|
||||
class MLFlowLogger(LightningLoggerBase):
|
||||
def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
|
||||
tags: Dict[str, Any] = None):
|
||||
r"""
|
||||
"""MLFLow logger"""
|
||||
|
||||
def __init__(self,
|
||||
experiment_name: str = 'default',
|
||||
tracking_uri: Optional[str] = None,
|
||||
tags: Optional[Dict[str, Any]] = None,
|
||||
save_dir: Optional[str] = None):
|
||||
r"""
|
||||
Logs using MLFlow
|
||||
|
||||
Args:
|
||||
|
@ -51,6 +56,8 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
tags (dict): todo this param
|
||||
"""
|
||||
super().__init__()
|
||||
if not tracking_uri and save_dir:
|
||||
tracking_uri = f'file:{os.sep * 2}{save_dir}'
|
||||
self._mlflow_client = MlflowClient(tracking_uri)
|
||||
self.experiment_name = experiment_name
|
||||
self._run_id = None
|
||||
|
@ -59,7 +66,6 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
@property
|
||||
def experiment(self) -> MlflowClient:
|
||||
r"""
|
||||
|
||||
Actual mlflow object. To use mlflow features do the following.
|
||||
|
||||
Example::
|
||||
|
@ -102,11 +108,9 @@ class MLFlowLogger(LightningLoggerBase):
|
|||
continue
|
||||
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str = 'FINISHED') -> None:
|
||||
super().finalize(status)
|
||||
if status == 'success':
|
||||
status = 'FINISHED'
|
||||
self.experiment.set_terminated(self.run_id, status)
|
||||
|
|
|
@ -29,13 +29,18 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
To log experiment data in online mode, NeptuneLogger requries an API key:
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, project_name: Optional[str] = None,
|
||||
close_after_fit: Optional[bool] = True, offline_mode: bool = False,
|
||||
def __init__(self,
|
||||
api_key: Optional[str] = None,
|
||||
project_name: Optional[str] = None,
|
||||
close_after_fit: Optional[bool] = True,
|
||||
offline_mode: bool = True,
|
||||
experiment_name: Optional[str] = None,
|
||||
upload_source_files: Optional[List[str]] = None, params: Optional[Dict[str, Any]] = None,
|
||||
properties: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, **kwargs):
|
||||
upload_source_files: Optional[List[str]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs):
|
||||
r"""
|
||||
|
||||
Initialize a neptune.ai logger.
|
||||
|
||||
.. note:: Requires either an API Key (online mode) or a local directory path (offline mode)
|
||||
|
@ -135,8 +140,8 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
"namespace/project_name" for example "tom/minst-classification".
|
||||
If None, the value of NEPTUNE_PROJECT environment variable will be taken.
|
||||
You need to create the project in https://neptune.ai first.
|
||||
offline_mode: Optional default False. If offline_mode=True no logs will be send
|
||||
to neptune. Usually used for debug purposes.
|
||||
offline_mode: Optional default True. If offline_mode=True no logs will be send
|
||||
to neptune. Usually used for debug and test purposes.
|
||||
close_after_fit: Optional default True. If close_after_fit=False the experiment
|
||||
will not be closed after training and additional metrics,
|
||||
images or artifacts can be logged. Also, remember to close the experiment explicitly
|
||||
|
@ -243,6 +248,7 @@ class NeptuneLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
super().finalize(status)
|
||||
if self.close_after_fit:
|
||||
self.experiment.stop()
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@ from pytorch_lightning import _logger as log
|
|||
|
||||
class TensorBoardLogger(LightningLoggerBase):
|
||||
r"""
|
||||
|
||||
Log to local file system in TensorBoard format
|
||||
|
||||
Implemented using :class:`torch.utils.tensorboard.SummaryWriter`. Logs are saved to
|
||||
|
@ -40,10 +39,11 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
"""
|
||||
NAME_CSV_TAGS = 'meta_tags.csv'
|
||||
|
||||
def __init__(
|
||||
self, save_dir: str, name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None, **kwargs
|
||||
):
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
name: Optional[str] = "default",
|
||||
version: Optional[Union[int, str]] = None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.save_dir = save_dir
|
||||
self._name = name
|
||||
|
@ -51,7 +51,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
|
||||
self._experiment = None
|
||||
self.tags = {}
|
||||
self.kwargs = kwargs
|
||||
self._kwargs = kwargs
|
||||
|
||||
@property
|
||||
def root_dir(self) -> str:
|
||||
|
@ -92,7 +92,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
return self._experiment
|
||||
|
||||
os.makedirs(self.root_dir, exist_ok=True)
|
||||
self._experiment = SummaryWriter(log_dir=self.log_dir, **self.kwargs)
|
||||
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
|
||||
return self._experiment
|
||||
|
||||
@rank_zero_only
|
||||
|
@ -127,6 +127,7 @@ class TensorBoardLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def save(self) -> None:
|
||||
super().save()
|
||||
try:
|
||||
self.experiment.flush()
|
||||
except AttributeError:
|
||||
|
|
|
@ -18,10 +18,13 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
|
||||
__test__ = False
|
||||
|
||||
def __init__(
|
||||
self, save_dir: str, name: str = "default", description: Optional[str] = None,
|
||||
debug: bool = False, version: Optional[int] = None, create_git_tag: bool = False
|
||||
):
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
name: str = "default",
|
||||
description: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
version: Optional[int] = None,
|
||||
create_git_tag: bool = False):
|
||||
r"""
|
||||
|
||||
Example
|
||||
|
@ -105,12 +108,14 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def save(self) -> None:
|
||||
super().save()
|
||||
# TODO: HACK figure out where this is being set to true
|
||||
self.experiment.debug = self.debug
|
||||
self.experiment.save()
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str) -> None:
|
||||
super().finalize(status)
|
||||
# TODO: HACK figure out where this is being set to true
|
||||
self.experiment.debug = self.debug
|
||||
self.save()
|
||||
|
@ -118,6 +123,7 @@ class TestTubeLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def close(self) -> None:
|
||||
super().save()
|
||||
# TODO: HACK figure out where this is being set to true
|
||||
self.experiment.debug = self.debug
|
||||
if not self.debug:
|
||||
|
|
|
@ -295,11 +295,9 @@ class TrainsLogger(LightningLoggerBase):
|
|||
delete_after_upload=delete_after_upload
|
||||
)
|
||||
|
||||
def save(self) -> None:
|
||||
pass
|
||||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str = None) -> None:
|
||||
# super().finalize(status)
|
||||
if self.bypass_mode() or not self._trains:
|
||||
return
|
||||
|
||||
|
|
|
@ -46,11 +46,18 @@ class WandbLogger(LightningLoggerBase):
|
|||
trainer = Trainer(logger=wandb_logger)
|
||||
"""
|
||||
|
||||
def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None,
|
||||
offline: bool = False, id: Optional[str] = None, anonymous: bool = False,
|
||||
version: Optional[str] = None, project: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None, log_model: bool = False,
|
||||
experiment=None, entity=None):
|
||||
def __init__(self,
|
||||
name: Optional[str] = None,
|
||||
save_dir: Optional[str] = None,
|
||||
offline: bool = False,
|
||||
id: Optional[str] = None,
|
||||
anonymous: bool = False,
|
||||
version: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
log_model: bool = False,
|
||||
experiment=None,
|
||||
entity=None):
|
||||
super().__init__()
|
||||
self._name = name
|
||||
self._save_dir = save_dir
|
||||
|
|
|
@ -370,14 +370,13 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
|
||||
# run evaluation
|
||||
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
|
||||
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
|
||||
eval_results)
|
||||
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
|
||||
|
||||
# add metrics to prog bar
|
||||
self.add_tqdm_metrics(prog_bar_metrics)
|
||||
|
||||
# log results of test
|
||||
if test_mode and self.proc_rank == 0 and len(callback_metrics) > 0:
|
||||
if test_mode and self.proc_rank == 0:
|
||||
print('-' * 80)
|
||||
print('TEST RESULTS')
|
||||
pprint(callback_metrics)
|
||||
|
|
|
@ -293,8 +293,7 @@ class Trainer(
|
|||
|
||||
# benchmarking
|
||||
self.benchmark = benchmark
|
||||
if benchmark:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.benchmark = self.benchmark
|
||||
|
||||
# Transfer params
|
||||
self.num_nodes = num_nodes
|
||||
|
|
|
@ -89,8 +89,8 @@ class LightValidationMixin(LightValidationStepMixin):
|
|||
val_loss_mean /= len(outputs)
|
||||
val_acc_mean /= len(outputs)
|
||||
|
||||
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
results = {'progress_bar': tqdm_dict, 'log': tqdm_dict}
|
||||
metrics_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
|
||||
results = {'progress_bar': metrics_dict, 'log': metrics_dict}
|
||||
return results
|
||||
|
||||
|
||||
|
@ -355,8 +355,8 @@ class LightTestMixin(LightTestStepMixin):
|
|||
test_loss_mean /= len(outputs)
|
||||
test_acc_mean /= len(outputs)
|
||||
|
||||
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
|
||||
result = {'progress_bar': tqdm_dict}
|
||||
metrics_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
|
||||
result = {'progress_bar': metrics_dict, 'log': metrics_dict}
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
import inspect
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import (
|
||||
TensorBoardLogger, MLFlowLogger, NeptuneLogger, TestTubeLogger, CometLogger)
|
||||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logger_class", [
|
||||
TensorBoardLogger,
|
||||
CometLogger,
|
||||
MLFlowLogger,
|
||||
NeptuneLogger,
|
||||
TestTubeLogger,
|
||||
# TrainsLogger, # TODO: add this one
|
||||
# WandbLogger, # TODO: add this one
|
||||
])
|
||||
def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
|
||||
"""Verify that basic functionality of all loggers."""
|
||||
tutils.reset_seed()
|
||||
|
||||
# prevent comet logger from trying to print at exit, since
|
||||
# pytest's stdout/stderr redirection breaks it
|
||||
import atexit
|
||||
monkeypatch.setattr(atexit, 'register', lambda _: None)
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
class StoreHistoryLogger(logger_class):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.history = []
|
||||
|
||||
def log_metrics(self, metrics, step):
|
||||
super().log_metrics(metrics, step)
|
||||
self.history.append((step, metrics))
|
||||
|
||||
if 'save_dir' in inspect.getfullargspec(logger_class).args:
|
||||
logger = StoreHistoryLogger(save_dir=str(tmpdir))
|
||||
else:
|
||||
logger = StoreHistoryLogger()
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
logger=logger,
|
||||
train_percent_check=0.2,
|
||||
val_percent_check=0.5,
|
||||
fast_dev_run=True,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
trainer.test()
|
||||
|
||||
log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
|
||||
assert log_metric_names == [(0, ['val_acc', 'val_loss']),
|
||||
(0, ['train_some_val']),
|
||||
(1, ['test_acc', 'test_loss'])]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logger_class", [
|
||||
TensorBoardLogger,
|
||||
CometLogger,
|
||||
MLFlowLogger,
|
||||
NeptuneLogger,
|
||||
TestTubeLogger,
|
||||
# TrainsLogger, # TODO: add this one
|
||||
# WandbLogger, # TODO: add this one
|
||||
])
|
||||
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
|
||||
"""Verify that pickling trainer with logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
# prevent comet logger from trying to print at exit, since
|
||||
# pytest's stdout/stderr redirection breaks it
|
||||
import atexit
|
||||
monkeypatch.setattr(atexit, 'register', lambda _: None)
|
||||
|
||||
if 'save_dir' in inspect.getfullargspec(logger_class).args:
|
||||
logger = logger_class(save_dir=str(tmpdir))
|
||||
else:
|
||||
logger = logger_class()
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=1,
|
||||
logger=logger
|
||||
)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({'acc': 1.0})
|
|
@ -1,5 +1,4 @@
|
|||
import pickle
|
||||
from collections import OrderedDict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
|
@ -59,18 +58,6 @@ class CustomLogger(LightningLoggerBase):
|
|||
return "1"
|
||||
|
||||
|
||||
class StoreHistoryLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.history = {}
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
if step not in self.history:
|
||||
self.history[step] = {}
|
||||
self.history[step].update(metrics)
|
||||
|
||||
|
||||
def test_custom_logger(tmpdir):
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
@ -175,6 +162,18 @@ def test_adding_step_key(tmpdir):
|
|||
|
||||
def test_with_accumulate_grad_batches():
|
||||
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""
|
||||
|
||||
class StoreHistoryLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.history = {}
|
||||
|
||||
@rank_zero_only
|
||||
def log_metrics(self, metrics, step):
|
||||
if step not in self.history:
|
||||
self.history[step] = {}
|
||||
self.history[step].update(metrics)
|
||||
|
||||
logger = StoreHistoryLogger()
|
||||
|
||||
np.random.seed(42)
|
||||
|
|
|
@ -1,51 +1,9 @@
|
|||
import os
|
||||
import pickle
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import CometLogger
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def test_comet_logger(tmpdir, monkeypatch):
|
||||
"""Verify that basic functionality of Comet.ml logger works."""
|
||||
|
||||
# prevent comet logger from trying to print at exit, since
|
||||
# pytest's stdout/stderr redirection breaks it
|
||||
import atexit
|
||||
monkeypatch.setattr(atexit, 'register', lambda _: None)
|
||||
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
comet_dir = os.path.join(tmpdir, 'cometruns')
|
||||
|
||||
# We test CometLogger in offline mode with local saves
|
||||
logger = CometLogger(
|
||||
save_dir=comet_dir,
|
||||
project_name='general',
|
||||
workspace='dummy-test',
|
||||
)
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
train_percent_check=0.05,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
trainer.logger.log_metrics({'acc': torch.ones(1)})
|
||||
|
||||
assert result == 1, 'Training failed'
|
||||
|
||||
|
||||
def test_comet_logger_online():
|
||||
|
@ -120,37 +78,3 @@ def test_comet_logger_online():
|
|||
)
|
||||
|
||||
api.assert_called_once_with('rest')
|
||||
|
||||
|
||||
def test_comet_pickle(tmpdir, monkeypatch):
|
||||
"""Verify that pickling trainer with comet logger works."""
|
||||
|
||||
# prevent comet logger from trying to print at exit, since
|
||||
# pytest's stdout/stderr redirection breaks it
|
||||
import atexit
|
||||
monkeypatch.setattr(atexit, 'register', lambda _: None)
|
||||
|
||||
tutils.reset_seed()
|
||||
|
||||
# hparams = tutils.get_default_hparams()
|
||||
# model = LightningTestModel(hparams)
|
||||
|
||||
comet_dir = os.path.join(tmpdir, 'cometruns')
|
||||
|
||||
# We test CometLogger in offline mode with local saves
|
||||
logger = CometLogger(
|
||||
save_dir=comet_dir,
|
||||
project_name='general',
|
||||
workspace='dummy-test',
|
||||
)
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({'acc': 1.0})
|
||||
|
|
|
@ -1,54 +1,9 @@
|
|||
import os
|
||||
import pickle
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def test_mlflow_logger(tmpdir):
|
||||
def test_mlflow_logger_exists(tmpdir):
|
||||
"""Verify that basic functionality of mlflow logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
mlflow_dir = os.path.join(tmpdir, 'mlruns')
|
||||
logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
|
||||
|
||||
logger = MLFlowLogger('test', save_dir=tmpdir)
|
||||
# Test already exists
|
||||
logger2 = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
|
||||
_ = logger2.run_id
|
||||
|
||||
# Try logging string
|
||||
logger.log_metrics({'acc': 'test'})
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
train_percent_check=0.05,
|
||||
logger=logger
|
||||
)
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'Training failed'
|
||||
|
||||
|
||||
def test_mlflow_pickle(tmpdir):
|
||||
"""Verify that pickling trainer with mlflow logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
mlflow_dir = os.path.join(tmpdir, 'mlruns')
|
||||
logger = MLFlowLogger('test', tracking_uri=f'file:{os.sep * 2}{mlflow_dir}')
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({'acc': 1.0})
|
||||
logger2 = MLFlowLogger('test', save_dir=tmpdir)
|
||||
assert logger.run_id != logger2.run_id
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import pickle
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import torch
|
||||
|
@ -9,29 +8,9 @@ from pytorch_lightning.loggers import NeptuneLogger
|
|||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def test_neptune_logger(tmpdir):
|
||||
"""Verify that basic functionality of neptune logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
logger = NeptuneLogger(offline_mode=True)
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
train_percent_check=0.05,
|
||||
logger=logger
|
||||
)
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'Training failed'
|
||||
|
||||
|
||||
@patch('pytorch_lightning.loggers.neptune.neptune')
|
||||
def test_neptune_online(neptune):
|
||||
logger = NeptuneLogger(api_key='test', project_name='project')
|
||||
logger = NeptuneLogger(api_key='test', offline_mode=False, project_name='project')
|
||||
neptune.init.assert_called_once_with(api_token='test', project_qualified_name='project')
|
||||
|
||||
assert logger.name == neptune.create_experiment().name
|
||||
|
@ -80,24 +59,6 @@ def test_neptune_additional_methods(neptune):
|
|||
neptune.create_experiment().append_tags.assert_called_once_with('two', 'tags')
|
||||
|
||||
|
||||
def test_neptune_pickle(tmpdir):
|
||||
"""Verify that pickling trainer with neptune logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
logger = NeptuneLogger(offline_mode=True)
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({'acc': 1.0})
|
||||
|
||||
|
||||
def test_neptune_leave_open_experiment_after_fit(tmpdir):
|
||||
"""Verify that neptune experiment was closed after training"""
|
||||
tutils.reset_seed()
|
||||
|
@ -121,6 +82,5 @@ def test_neptune_leave_open_experiment_after_fit(tmpdir):
|
|||
logger_close_after_fit = _run_training(NeptuneLogger(offline_mode=True))
|
||||
assert logger_close_after_fit._experiment.stop.call_count == 1
|
||||
|
||||
logger_open_after_fit = _run_training(
|
||||
NeptuneLogger(offline_mode=True, close_after_fit=False))
|
||||
logger_open_after_fit = _run_training(NeptuneLogger(offline_mode=True, close_after_fit=False))
|
||||
assert logger_open_after_fit._experiment.stop.call_count == 0
|
||||
|
|
|
@ -1,43 +1,9 @@
|
|||
import pickle
|
||||
from argparse import Namespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def test_tensorboard_logger(tmpdir):
|
||||
"""Verify that basic functionality of Tensorboard logger works."""
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name="tensorboard_logger_test")
|
||||
|
||||
trainer_options = dict(max_epochs=1, train_percent_check=0.01, logger=logger)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
print("result finished")
|
||||
assert result == 1, "Training failed"
|
||||
|
||||
|
||||
def test_tensorboard_pickle(tmpdir):
|
||||
"""Verify that pickling trainer with Tensorboard logger works."""
|
||||
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name="tensorboard_pickle_test")
|
||||
|
||||
trainer_options = dict(max_epochs=1, logger=logger)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({"acc": 1.0})
|
||||
|
||||
|
||||
def test_tensorboard_automatic_versioning(tmpdir):
|
||||
|
@ -79,13 +45,10 @@ def test_tensorboard_named_version(tmpdir):
|
|||
# in the "minimum requirements" test setup
|
||||
|
||||
|
||||
def test_tensorboard_no_name(tmpdir):
|
||||
@pytest.mark.parametrize("name", ['', None])
|
||||
def test_tensorboard_no_name(tmpdir, name):
|
||||
"""Verify that None or empty name works"""
|
||||
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name="")
|
||||
assert logger.root_dir == tmpdir
|
||||
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name=None)
|
||||
logger = TensorBoardLogger(save_dir=tmpdir, name=name)
|
||||
assert logger.root_dir == tmpdir
|
||||
|
||||
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
import pickle
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.base import LightningTestModel
|
||||
|
||||
|
||||
def test_testtube_logger(tmpdir):
|
||||
"""Verify that basic functionality of test tube logger works."""
|
||||
tutils.reset_seed()
|
||||
hparams = tutils.get_default_hparams()
|
||||
model = LightningTestModel(hparams)
|
||||
|
||||
logger = tutils.get_default_testtube_logger(tmpdir, False)
|
||||
|
||||
assert logger.name == 'lightning_logs'
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
train_percent_check=0.05,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
assert result == 1, 'Training failed'
|
||||
|
||||
|
||||
def test_testtube_pickle(tmpdir):
|
||||
"""Verify that pickling a trainer containing a test tube logger works."""
|
||||
tutils.reset_seed()
|
||||
|
||||
hparams = tutils.get_default_hparams()
|
||||
|
||||
logger = tutils.get_default_testtube_logger(tmpdir, False)
|
||||
logger.log_hyperparams(hparams)
|
||||
logger.save()
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
train_percent_check=0.05,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({'acc': 1.0})
|
|
@ -2,8 +2,6 @@ import os
|
|||
import pickle
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
@ -37,7 +35,9 @@ def test_wandb_logger(wandb):
|
|||
@patch('pytorch_lightning.loggers.wandb.wandb')
|
||||
def test_wandb_pickle(wandb):
|
||||
"""Verify that pickling trainer with wandb logger works.
|
||||
Wandb doesn't work well with pytest so we have to mock it out here."""
|
||||
|
||||
Wandb doesn't work well with pytest so we have to mock it out here.
|
||||
"""
|
||||
tutils.reset_seed()
|
||||
|
||||
class Experiment:
|
||||
|
|
Loading…
Reference in New Issue