From fb8d085b5fc2f858ed6740fcc366ddebab3b6e8e Mon Sep 17 00:00:00 2001 From: "Martin.B" <51887611+bmartinn@users.noreply.github.com> Date: Wed, 8 Apr 2020 18:52:52 +0300 Subject: [PATCH] Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379) * Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) * fix * test ci * debug * debug CI * Fix CircleCI * Fix Any CI environment switch to bypass mode * Removed debug prints * Improve code coverage * Improve code coverage * Reverted * Improve code coverage * Test CI * test codecov * Codecov fix * remove pragma Co-authored-by: bmartinn <> --- pytorch_lightning/loggers/trains.py | 81 +++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/loggers/trains.py b/pytorch_lightning/loggers/trains.py index 55fc0abbf7..3f890bba27 100644 --- a/pytorch_lightning/loggers/trains.py +++ b/pytorch_lightning/loggers/trains.py @@ -24,6 +24,7 @@ Use the logger anywhere in you LightningModule as follows: """ from argparse import Namespace +from os import environ from pathlib import Path from typing import Any, Dict, Optional, Union @@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase): sent along side the task scalars. Defaults to True. Examples: - >>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS + >>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS TRAINS Task: ... - TRAINS results page: https://demoapp.trains.allegro.ai/.../log + TRAINS results page: ... >>> logger.log_metrics({"val_loss": 1.23}, step=0) >>> logger.log_text("sample test") sample test @@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase): >>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8)) """ - _bypass = False + _bypass = None def __init__( self, @@ -83,8 +84,24 @@ class TrainsLogger(LightningLoggerBase): auto_resource_monitoring: bool = True ) -> None: super().__init__() - if self._bypass: + if self.bypass_mode(): self._trains = None + print('TRAINS Task: running in bypass mode') + print('TRAINS results page: disabled') + + class _TaskStub(object): + def __call__(self, *args, **kwargs): + return self + + def __getattr__(self, attr): + if attr in ('name', 'id'): + return '' + return self + + def __setattr__(self, attr, val): + pass + + self._trains = _TaskStub() else: self._trains = Task.init( project_name=project_name, @@ -114,8 +131,9 @@ class TrainsLogger(LightningLoggerBase): """ ID is a uuid (string) representing this specific experiment in the entire system. """ - if self._bypass or not self._trains: + if not self._trains: return None + return self._trains.id @rank_zero_only @@ -126,8 +144,8 @@ class TrainsLogger(LightningLoggerBase): params: The hyperparameters that passed through the model. """ - if self._bypass or not self._trains: - return None + if not self._trains: + return if not params: return @@ -147,8 +165,8 @@ class TrainsLogger(LightningLoggerBase): then the elements will be logged as "title" and "series" respectively. step: Step number at which the metrics should be recorded. Defaults to None. """ - if self._bypass or not self._trains: - return None + if not self._trains: + return if not step: step = self._trains.get_last_iteration() @@ -179,8 +197,8 @@ class TrainsLogger(LightningLoggerBase): value: The value to log. step: Step number at which the metrics should be recorded. Defaults to None. """ - if self._bypass or not self._trains: - return None + if not self._trains: + return if not step: step = self._trains.get_last_iteration() @@ -197,8 +215,12 @@ class TrainsLogger(LightningLoggerBase): Args: text: The value of the log (data-point). """ - if self._bypass or not self._trains: - return None + if self.bypass_mode(): + print(text) + return + + if not self._trains: + return self._trains.get_logger().report_text(text) @@ -222,8 +244,8 @@ class TrainsLogger(LightningLoggerBase): step: Step number at which the metrics should be recorded. Defaults to None. """ - if self._bypass or not self._trains: - return None + if not self._trains: + return if not step: step = self._trains.get_last_iteration() @@ -265,8 +287,8 @@ class TrainsLogger(LightningLoggerBase): If True local artifact will be deleted (only applies if artifact_object is a local file). Defaults to False. """ - if self._bypass or not self._trains: - return None + if not self._trains: + return self._trains.upload_artifact( name=name, artifact_object=artifact, metadata=metadata, @@ -278,8 +300,9 @@ class TrainsLogger(LightningLoggerBase): @rank_zero_only def finalize(self, status: str = None) -> None: - if self._bypass or not self._trains: - return None + if self.bypass_mode() or not self._trains: + return + self._trains.close() self._trains = None @@ -288,14 +311,16 @@ class TrainsLogger(LightningLoggerBase): """ Name is a human readable non-unique name (str) of the experiment. """ - if self._bypass or not self._trains: + if not self._trains: return '' + return self._trains.name @property def version(self) -> Union[str, None]: - if self._bypass or not self._trains: + if not self._trains: return None + return self._trains.id @classmethod @@ -327,9 +352,21 @@ class TrainsLogger(LightningLoggerBase): """ cls._bypass = bypass + @classmethod + def bypass_mode(cls) -> bool: + """ + bypass_mode returns the bypass mode state. + Notice GITHUB_ACTIONS env will automatically set bypass_mode to True + unless overridden specifically with set_bypass_mode(False) + + :return: If True, all outside communication is skipped + """ + return cls._bypass if cls._bypass is not None else bool(environ.get('CI')) + def __getstate__(self) -> Union[str, None]: - if self._bypass or not self._trains: + if self.bypass_mode() or not self._trains: return '' + return self._trains.id def __setstate__(self, state: str) -> None: