Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379)

* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI)

* fix

* test ci

* debug

* debug CI

* Fix CircleCI

* Fix Any CI environment switch to bypass mode

* Removed debug prints

* Improve code coverage

* Improve code coverage

* Reverted

* Improve code coverage

* Test CI

* test codecov

* Codecov fix

* remove pragma

Co-authored-by: bmartinn <>
This commit is contained in:
Martin.B 2020-04-08 18:52:52 +03:00 committed by GitHub
parent 2ae2bd2b46
commit fb8d085b5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 59 additions and 22 deletions

View File

@ -24,6 +24,7 @@ Use the logger anywhere in you LightningModule as follows:
"""
from argparse import Namespace
from os import environ
from pathlib import Path
from typing import Any, Dict, Optional, Union
@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
sent along side the task scalars. Defaults to True.
Examples:
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS
>>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
TRAINS Task: ...
TRAINS results page: https://demoapp.trains.allegro.ai/.../log
TRAINS results page: ...
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
>>> logger.log_text("sample test")
sample test
@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
"""
_bypass = False
_bypass = None
def __init__(
self,
@ -83,8 +84,24 @@ class TrainsLogger(LightningLoggerBase):
auto_resource_monitoring: bool = True
) -> None:
super().__init__()
if self._bypass:
if self.bypass_mode():
self._trains = None
print('TRAINS Task: running in bypass mode')
print('TRAINS results page: disabled')
class _TaskStub(object):
def __call__(self, *args, **kwargs):
return self
def __getattr__(self, attr):
if attr in ('name', 'id'):
return ''
return self
def __setattr__(self, attr, val):
pass
self._trains = _TaskStub()
else:
self._trains = Task.init(
project_name=project_name,
@ -114,8 +131,9 @@ class TrainsLogger(LightningLoggerBase):
"""
ID is a uuid (string) representing this specific experiment in the entire system.
"""
if self._bypass or not self._trains:
if not self._trains:
return None
return self._trains.id
@rank_zero_only
@ -126,8 +144,8 @@ class TrainsLogger(LightningLoggerBase):
params:
The hyperparameters that passed through the model.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
if not params:
return
@ -147,8 +165,8 @@ class TrainsLogger(LightningLoggerBase):
then the elements will be logged as "title" and "series" respectively.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
if not step:
step = self._trains.get_last_iteration()
@ -179,8 +197,8 @@ class TrainsLogger(LightningLoggerBase):
value: The value to log.
step: Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
if not step:
step = self._trains.get_last_iteration()
@ -197,8 +215,12 @@ class TrainsLogger(LightningLoggerBase):
Args:
text: The value of the log (data-point).
"""
if self._bypass or not self._trains:
return None
if self.bypass_mode():
print(text)
return
if not self._trains:
return
self._trains.get_logger().report_text(text)
@ -222,8 +244,8 @@ class TrainsLogger(LightningLoggerBase):
step:
Step number at which the metrics should be recorded. Defaults to None.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
if not step:
step = self._trains.get_last_iteration()
@ -265,8 +287,8 @@ class TrainsLogger(LightningLoggerBase):
If True local artifact will be deleted (only applies if artifact_object is a
local file). Defaults to False.
"""
if self._bypass or not self._trains:
return None
if not self._trains:
return
self._trains.upload_artifact(
name=name, artifact_object=artifact, metadata=metadata,
@ -278,8 +300,9 @@ class TrainsLogger(LightningLoggerBase):
@rank_zero_only
def finalize(self, status: str = None) -> None:
if self._bypass or not self._trains:
return None
if self.bypass_mode() or not self._trains:
return
self._trains.close()
self._trains = None
@ -288,14 +311,16 @@ class TrainsLogger(LightningLoggerBase):
"""
Name is a human readable non-unique name (str) of the experiment.
"""
if self._bypass or not self._trains:
if not self._trains:
return ''
return self._trains.name
@property
def version(self) -> Union[str, None]:
if self._bypass or not self._trains:
if not self._trains:
return None
return self._trains.id
@classmethod
@ -327,9 +352,21 @@ class TrainsLogger(LightningLoggerBase):
"""
cls._bypass = bypass
@classmethod
def bypass_mode(cls) -> bool:
"""
bypass_mode returns the bypass mode state.
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
unless overridden specifically with set_bypass_mode(False)
:return: If True, all outside communication is skipped
"""
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))
def __getstate__(self) -> Union[str, None]:
if self._bypass or not self._trains:
if self.bypass_mode() or not self._trains:
return ''
return self._trains.id
def __setstate__(self, state: str) -> None: