Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379)
* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) * fix * test ci * debug * debug CI * Fix CircleCI * Fix Any CI environment switch to bypass mode * Removed debug prints * Improve code coverage * Improve code coverage * Reverted * Improve code coverage * Test CI * test codecov * Codecov fix * remove pragma Co-authored-by: bmartinn <>
This commit is contained in:
parent
2ae2bd2b46
commit
fb8d085b5f
|
@ -24,6 +24,7 @@ Use the logger anywhere in you LightningModule as follows:
|
|||
|
||||
"""
|
||||
from argparse import Namespace
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
|
@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
|
|||
sent along side the task scalars. Defaults to True.
|
||||
|
||||
Examples:
|
||||
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS
|
||||
>>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
|
||||
TRAINS Task: ...
|
||||
TRAINS results page: https://demoapp.trains.allegro.ai/.../log
|
||||
TRAINS results page: ...
|
||||
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
|
||||
>>> logger.log_text("sample test")
|
||||
sample test
|
||||
|
@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
|
|||
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
|
||||
"""
|
||||
|
||||
_bypass = False
|
||||
_bypass = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -83,8 +84,24 @@ class TrainsLogger(LightningLoggerBase):
|
|||
auto_resource_monitoring: bool = True
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if self._bypass:
|
||||
if self.bypass_mode():
|
||||
self._trains = None
|
||||
print('TRAINS Task: running in bypass mode')
|
||||
print('TRAINS results page: disabled')
|
||||
|
||||
class _TaskStub(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in ('name', 'id'):
|
||||
return ''
|
||||
return self
|
||||
|
||||
def __setattr__(self, attr, val):
|
||||
pass
|
||||
|
||||
self._trains = _TaskStub()
|
||||
else:
|
||||
self._trains = Task.init(
|
||||
project_name=project_name,
|
||||
|
@ -114,8 +131,9 @@ class TrainsLogger(LightningLoggerBase):
|
|||
"""
|
||||
ID is a uuid (string) representing this specific experiment in the entire system.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
if not self._trains:
|
||||
return None
|
||||
|
||||
return self._trains.id
|
||||
|
||||
@rank_zero_only
|
||||
|
@ -126,8 +144,8 @@ class TrainsLogger(LightningLoggerBase):
|
|||
params:
|
||||
The hyperparameters that passed through the model.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if not self._trains:
|
||||
return
|
||||
if not params:
|
||||
return
|
||||
|
||||
|
@ -147,8 +165,8 @@ class TrainsLogger(LightningLoggerBase):
|
|||
then the elements will be logged as "title" and "series" respectively.
|
||||
step: Step number at which the metrics should be recorded. Defaults to None.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if not self._trains:
|
||||
return
|
||||
|
||||
if not step:
|
||||
step = self._trains.get_last_iteration()
|
||||
|
@ -179,8 +197,8 @@ class TrainsLogger(LightningLoggerBase):
|
|||
value: The value to log.
|
||||
step: Step number at which the metrics should be recorded. Defaults to None.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if not self._trains:
|
||||
return
|
||||
|
||||
if not step:
|
||||
step = self._trains.get_last_iteration()
|
||||
|
@ -197,8 +215,12 @@ class TrainsLogger(LightningLoggerBase):
|
|||
Args:
|
||||
text: The value of the log (data-point).
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if self.bypass_mode():
|
||||
print(text)
|
||||
return
|
||||
|
||||
if not self._trains:
|
||||
return
|
||||
|
||||
self._trains.get_logger().report_text(text)
|
||||
|
||||
|
@ -222,8 +244,8 @@ class TrainsLogger(LightningLoggerBase):
|
|||
step:
|
||||
Step number at which the metrics should be recorded. Defaults to None.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if not self._trains:
|
||||
return
|
||||
|
||||
if not step:
|
||||
step = self._trains.get_last_iteration()
|
||||
|
@ -265,8 +287,8 @@ class TrainsLogger(LightningLoggerBase):
|
|||
If True local artifact will be deleted (only applies if artifact_object is a
|
||||
local file). Defaults to False.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if not self._trains:
|
||||
return
|
||||
|
||||
self._trains.upload_artifact(
|
||||
name=name, artifact_object=artifact, metadata=metadata,
|
||||
|
@ -278,8 +300,9 @@ class TrainsLogger(LightningLoggerBase):
|
|||
|
||||
@rank_zero_only
|
||||
def finalize(self, status: str = None) -> None:
|
||||
if self._bypass or not self._trains:
|
||||
return None
|
||||
if self.bypass_mode() or not self._trains:
|
||||
return
|
||||
|
||||
self._trains.close()
|
||||
self._trains = None
|
||||
|
||||
|
@ -288,14 +311,16 @@ class TrainsLogger(LightningLoggerBase):
|
|||
"""
|
||||
Name is a human readable non-unique name (str) of the experiment.
|
||||
"""
|
||||
if self._bypass or not self._trains:
|
||||
if not self._trains:
|
||||
return ''
|
||||
|
||||
return self._trains.name
|
||||
|
||||
@property
|
||||
def version(self) -> Union[str, None]:
|
||||
if self._bypass or not self._trains:
|
||||
if not self._trains:
|
||||
return None
|
||||
|
||||
return self._trains.id
|
||||
|
||||
@classmethod
|
||||
|
@ -327,9 +352,21 @@ class TrainsLogger(LightningLoggerBase):
|
|||
"""
|
||||
cls._bypass = bypass
|
||||
|
||||
@classmethod
|
||||
def bypass_mode(cls) -> bool:
|
||||
"""
|
||||
bypass_mode returns the bypass mode state.
|
||||
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
|
||||
unless overridden specifically with set_bypass_mode(False)
|
||||
|
||||
:return: If True, all outside communication is skipped
|
||||
"""
|
||||
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))
|
||||
|
||||
def __getstate__(self) -> Union[str, None]:
|
||||
if self._bypass or not self._trains:
|
||||
if self.bypass_mode() or not self._trains:
|
||||
return ''
|
||||
|
||||
return self._trains.id
|
||||
|
||||
def __setstate__(self, state: str) -> None:
|
||||
|
|
Loading…
Reference in New Issue