Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379)
* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) * fix * test ci * debug * debug CI * Fix CircleCI * Fix Any CI environment switch to bypass mode * Removed debug prints * Improve code coverage * Improve code coverage * Reverted * Improve code coverage * Test CI * test codecov * Codecov fix * remove pragma Co-authored-by: bmartinn <>
This commit is contained in:
parent
2ae2bd2b46
commit
fb8d085b5f
|
@ -24,6 +24,7 @@ Use the logger anywhere in you LightningModule as follows:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from os import environ
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
|
@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
sent along side the task scalars. Defaults to True.
|
sent along side the task scalars. Defaults to True.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS
|
>>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
|
||||||
TRAINS Task: ...
|
TRAINS Task: ...
|
||||||
TRAINS results page: https://demoapp.trains.allegro.ai/.../log
|
TRAINS results page: ...
|
||||||
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
|
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
|
||||||
>>> logger.log_text("sample test")
|
>>> logger.log_text("sample test")
|
||||||
sample test
|
sample test
|
||||||
|
@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
|
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_bypass = False
|
_bypass = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -83,8 +84,24 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
auto_resource_monitoring: bool = True
|
auto_resource_monitoring: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if self._bypass:
|
if self.bypass_mode():
|
||||||
self._trains = None
|
self._trains = None
|
||||||
|
print('TRAINS Task: running in bypass mode')
|
||||||
|
print('TRAINS results page: disabled')
|
||||||
|
|
||||||
|
class _TaskStub(object):
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in ('name', 'id'):
|
||||||
|
return ''
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __setattr__(self, attr, val):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._trains = _TaskStub()
|
||||||
else:
|
else:
|
||||||
self._trains = Task.init(
|
self._trains = Task.init(
|
||||||
project_name=project_name,
|
project_name=project_name,
|
||||||
|
@ -114,8 +131,9 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
"""
|
"""
|
||||||
ID is a uuid (string) representing this specific experiment in the entire system.
|
ID is a uuid (string) representing this specific experiment in the entire system.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._trains.id
|
return self._trains.id
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
|
@ -126,8 +144,8 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
params:
|
params:
|
||||||
The hyperparameters that passed through the model.
|
The hyperparameters that passed through the model.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return
|
||||||
if not params:
|
if not params:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -147,8 +165,8 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
then the elements will be logged as "title" and "series" respectively.
|
then the elements will be logged as "title" and "series" respectively.
|
||||||
step: Step number at which the metrics should be recorded. Defaults to None.
|
step: Step number at which the metrics should be recorded. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return
|
||||||
|
|
||||||
if not step:
|
if not step:
|
||||||
step = self._trains.get_last_iteration()
|
step = self._trains.get_last_iteration()
|
||||||
|
@ -179,8 +197,8 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
value: The value to log.
|
value: The value to log.
|
||||||
step: Step number at which the metrics should be recorded. Defaults to None.
|
step: Step number at which the metrics should be recorded. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return
|
||||||
|
|
||||||
if not step:
|
if not step:
|
||||||
step = self._trains.get_last_iteration()
|
step = self._trains.get_last_iteration()
|
||||||
|
@ -197,8 +215,12 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
Args:
|
Args:
|
||||||
text: The value of the log (data-point).
|
text: The value of the log (data-point).
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if self.bypass_mode():
|
||||||
return None
|
print(text)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._trains:
|
||||||
|
return
|
||||||
|
|
||||||
self._trains.get_logger().report_text(text)
|
self._trains.get_logger().report_text(text)
|
||||||
|
|
||||||
|
@ -222,8 +244,8 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
step:
|
step:
|
||||||
Step number at which the metrics should be recorded. Defaults to None.
|
Step number at which the metrics should be recorded. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return
|
||||||
|
|
||||||
if not step:
|
if not step:
|
||||||
step = self._trains.get_last_iteration()
|
step = self._trains.get_last_iteration()
|
||||||
|
@ -265,8 +287,8 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
If True local artifact will be deleted (only applies if artifact_object is a
|
If True local artifact will be deleted (only applies if artifact_object is a
|
||||||
local file). Defaults to False.
|
local file). Defaults to False.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return
|
||||||
|
|
||||||
self._trains.upload_artifact(
|
self._trains.upload_artifact(
|
||||||
name=name, artifact_object=artifact, metadata=metadata,
|
name=name, artifact_object=artifact, metadata=metadata,
|
||||||
|
@ -278,8 +300,9 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
def finalize(self, status: str = None) -> None:
|
def finalize(self, status: str = None) -> None:
|
||||||
if self._bypass or not self._trains:
|
if self.bypass_mode() or not self._trains:
|
||||||
return None
|
return
|
||||||
|
|
||||||
self._trains.close()
|
self._trains.close()
|
||||||
self._trains = None
|
self._trains = None
|
||||||
|
|
||||||
|
@ -288,14 +311,16 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
"""
|
"""
|
||||||
Name is a human readable non-unique name (str) of the experiment.
|
Name is a human readable non-unique name (str) of the experiment.
|
||||||
"""
|
"""
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
return self._trains.name
|
return self._trains.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def version(self) -> Union[str, None]:
|
def version(self) -> Union[str, None]:
|
||||||
if self._bypass or not self._trains:
|
if not self._trains:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self._trains.id
|
return self._trains.id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -327,9 +352,21 @@ class TrainsLogger(LightningLoggerBase):
|
||||||
"""
|
"""
|
||||||
cls._bypass = bypass
|
cls._bypass = bypass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def bypass_mode(cls) -> bool:
|
||||||
|
"""
|
||||||
|
bypass_mode returns the bypass mode state.
|
||||||
|
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
|
||||||
|
unless overridden specifically with set_bypass_mode(False)
|
||||||
|
|
||||||
|
:return: If True, all outside communication is skipped
|
||||||
|
"""
|
||||||
|
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))
|
||||||
|
|
||||||
def __getstate__(self) -> Union[str, None]:
|
def __getstate__(self) -> Union[str, None]:
|
||||||
if self._bypass or not self._trains:
|
if self.bypass_mode() or not self._trains:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
return self._trains.id
|
return self._trains.id
|
||||||
|
|
||||||
def __setstate__(self, state: str) -> None:
|
def __setstate__(self, state: str) -> None:
|
||||||
|
|
Loading…
Reference in New Issue