Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379)

* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI)

* fix

* test ci

* debug

* debug CI

* Fix CircleCI

* Fix Any CI environment switch to bypass mode

* Removed debug prints

* Improve code coverage

* Improve code coverage

* Reverted

* Improve code coverage

* Test CI

* test codecov

* Codecov fix

* remove pragma

Co-authored-by: bmartinn <>
This commit is contained in:
Martin.B 2020-04-08 18:52:52 +03:00 committed by GitHub
parent 2ae2bd2b46
commit fb8d085b5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 59 additions and 22 deletions

View File

@ -24,6 +24,7 @@ Use the logger anywhere in you LightningModule as follows:
""" """
from argparse import Namespace from argparse import Namespace
from os import environ
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
sent along side the task scalars. Defaults to True. sent along side the task scalars. Defaults to True.
Examples: Examples:
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS >>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
TRAINS Task: ... TRAINS Task: ...
TRAINS results page: https://demoapp.trains.allegro.ai/.../log TRAINS results page: ...
>>> logger.log_metrics({"val_loss": 1.23}, step=0) >>> logger.log_metrics({"val_loss": 1.23}, step=0)
>>> logger.log_text("sample test") >>> logger.log_text("sample test")
sample test sample test
@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8)) >>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
""" """
_bypass = False _bypass = None
def __init__( def __init__(
self, self,
@ -83,8 +84,24 @@ class TrainsLogger(LightningLoggerBase):
auto_resource_monitoring: bool = True auto_resource_monitoring: bool = True
) -> None: ) -> None:
super().__init__() super().__init__()
if self._bypass: if self.bypass_mode():
self._trains = None self._trains = None
print('TRAINS Task: running in bypass mode')
print('TRAINS results page: disabled')
class _TaskStub(object):
def __call__(self, *args, **kwargs):
return self
def __getattr__(self, attr):
if attr in ('name', 'id'):
return ''
return self
def __setattr__(self, attr, val):
pass
self._trains = _TaskStub()
else: else:
self._trains = Task.init( self._trains = Task.init(
project_name=project_name, project_name=project_name,
@ -114,8 +131,9 @@ class TrainsLogger(LightningLoggerBase):
""" """
ID is a uuid (string) representing this specific experiment in the entire system. ID is a uuid (string) representing this specific experiment in the entire system.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return None
return self._trains.id return self._trains.id
@rank_zero_only @rank_zero_only
@ -126,8 +144,8 @@ class TrainsLogger(LightningLoggerBase):
params: params:
The hyperparameters that passed through the model. The hyperparameters that passed through the model.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return
if not params: if not params:
return return
@ -147,8 +165,8 @@ class TrainsLogger(LightningLoggerBase):
then the elements will be logged as "title" and "series" respectively. then the elements will be logged as "title" and "series" respectively.
step: Step number at which the metrics should be recorded. Defaults to None. step: Step number at which the metrics should be recorded. Defaults to None.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return
if not step: if not step:
step = self._trains.get_last_iteration() step = self._trains.get_last_iteration()
@ -179,8 +197,8 @@ class TrainsLogger(LightningLoggerBase):
value: The value to log. value: The value to log.
step: Step number at which the metrics should be recorded. Defaults to None. step: Step number at which the metrics should be recorded. Defaults to None.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return
if not step: if not step:
step = self._trains.get_last_iteration() step = self._trains.get_last_iteration()
@ -197,8 +215,12 @@ class TrainsLogger(LightningLoggerBase):
Args: Args:
text: The value of the log (data-point). text: The value of the log (data-point).
""" """
if self._bypass or not self._trains: if self.bypass_mode():
return None print(text)
return
if not self._trains:
return
self._trains.get_logger().report_text(text) self._trains.get_logger().report_text(text)
@ -222,8 +244,8 @@ class TrainsLogger(LightningLoggerBase):
step: step:
Step number at which the metrics should be recorded. Defaults to None. Step number at which the metrics should be recorded. Defaults to None.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return
if not step: if not step:
step = self._trains.get_last_iteration() step = self._trains.get_last_iteration()
@ -265,8 +287,8 @@ class TrainsLogger(LightningLoggerBase):
If True local artifact will be deleted (only applies if artifact_object is a If True local artifact will be deleted (only applies if artifact_object is a
local file). Defaults to False. local file). Defaults to False.
""" """
if self._bypass or not self._trains: if not self._trains:
return None return
self._trains.upload_artifact( self._trains.upload_artifact(
name=name, artifact_object=artifact, metadata=metadata, name=name, artifact_object=artifact, metadata=metadata,
@ -278,8 +300,9 @@ class TrainsLogger(LightningLoggerBase):
@rank_zero_only @rank_zero_only
def finalize(self, status: str = None) -> None: def finalize(self, status: str = None) -> None:
if self._bypass or not self._trains: if self.bypass_mode() or not self._trains:
return None return
self._trains.close() self._trains.close()
self._trains = None self._trains = None
@ -288,14 +311,16 @@ class TrainsLogger(LightningLoggerBase):
""" """
Name is a human readable non-unique name (str) of the experiment. Name is a human readable non-unique name (str) of the experiment.
""" """
if self._bypass or not self._trains: if not self._trains:
return '' return ''
return self._trains.name return self._trains.name
@property @property
def version(self) -> Union[str, None]: def version(self) -> Union[str, None]:
if self._bypass or not self._trains: if not self._trains:
return None return None
return self._trains.id return self._trains.id
@classmethod @classmethod
@ -327,9 +352,21 @@ class TrainsLogger(LightningLoggerBase):
""" """
cls._bypass = bypass cls._bypass = bypass
@classmethod
def bypass_mode(cls) -> bool:
"""
bypass_mode returns the bypass mode state.
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
unless overridden specifically with set_bypass_mode(False)
:return: If True, all outside communication is skipped
"""
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))
def __getstate__(self) -> Union[str, None]: def __getstate__(self) -> Union[str, None]:
if self._bypass or not self._trains: if self.bypass_mode() or not self._trains:
return '' return ''
return self._trains.id return self._trains.id
def __setstate__(self, state: str) -> None: def __setstate__(self, state: str) -> None: