diff --git a/tests/tests_contrib_logging.py b/tests/tests_contrib_logging.py index eaccb615..6f675dd7 100644 --- a/tests/tests_contrib_logging.py +++ b/tests/tests_contrib_logging.py @@ -10,7 +10,7 @@ from io import StringIO import pytest from tqdm import tqdm -from tqdm.contrib.logging import _get_first_found_console_logging_formatter +from tqdm.contrib.logging import _get_first_found_console_logging_handler from tqdm.contrib.logging import _TqdmLoggingHandler as TqdmLoggingHandler from tqdm.contrib.logging import logging_redirect_tqdm, tqdm_logging_redirect @@ -68,33 +68,25 @@ class TestTqdmLoggingHandler: logger.info('test') -class TestGetFirstFoundConsoleLoggingFormatter: +class TestGetFirstFoundConsoleLoggingHandler: def test_should_return_none_for_no_handlers(self): - assert _get_first_found_console_logging_formatter([]) is None + assert _get_first_found_console_logging_handler([]) is None def test_should_return_none_without_stream_handler(self): handler = logging.handlers.MemoryHandler(capacity=1) - handler.formatter = TEST_LOGGING_FORMATTER - assert _get_first_found_console_logging_formatter([handler]) is None + assert _get_first_found_console_logging_handler([handler]) is None def test_should_return_none_for_stream_handler_not_stdout_or_stderr(self): handler = logging.StreamHandler(StringIO()) - handler.formatter = TEST_LOGGING_FORMATTER - assert _get_first_found_console_logging_formatter([handler]) is None + assert _get_first_found_console_logging_handler([handler]) is None - def test_should_return_stream_handler_formatter_if_stream_is_stdout(self): + def test_should_return_stream_handler_if_stream_is_stdout(self): handler = logging.StreamHandler(sys.stdout) - handler.formatter = TEST_LOGGING_FORMATTER - assert _get_first_found_console_logging_formatter( - [handler] - ) == TEST_LOGGING_FORMATTER + assert _get_first_found_console_logging_handler([handler]) == handler - def test_should_return_stream_handler_formatter_if_stream_is_stderr(self): + def test_should_return_stream_handler_if_stream_is_stderr(self): handler = logging.StreamHandler(sys.stderr) - handler.formatter = TEST_LOGGING_FORMATTER - assert _get_first_found_console_logging_formatter( - [handler] - ) == TEST_LOGGING_FORMATTER + assert _get_first_found_console_logging_handler([handler]) == handler class TestRedirectLoggingToTqdm: diff --git a/tqdm/contrib/logging.py b/tqdm/contrib/logging.py index 5f70944d..cd9925ec 100644 --- a/tqdm/contrib/logging.py +++ b/tqdm/contrib/logging.py @@ -26,7 +26,7 @@ class _TqdmLoggingHandler(logging.StreamHandler): def emit(self, record): try: msg = self.format(record) - self.tqdm_class.write(msg) + self.tqdm_class.write(msg, file=self.stream) self.flush() except (KeyboardInterrupt, SystemExit): raise @@ -39,10 +39,10 @@ def _is_console_logging_handler(handler): and handler.stream in {sys.stdout, sys.stderr}) -def _get_first_found_console_logging_formatter(handlers): +def _get_first_found_console_logging_handler(handlers): for handler in handlers: if _is_console_logging_handler(handler): - return handler.formatter + return handler @contextmanager @@ -85,8 +85,10 @@ def logging_redirect_tqdm( try: for logger in loggers: tqdm_handler = _TqdmLoggingHandler(tqdm_class) - tqdm_handler.setFormatter( - _get_first_found_console_logging_formatter(logger.handlers)) + orig_handler = _get_first_found_console_logging_handler(logger.handlers) + if orig_handler is not None: + tqdm_handler.setFormatter(orig_handler.formatter) + tqdm_handler.stream = orig_handler.stream logger.handlers = [ handler for handler in logger.handlers if not _is_console_logging_handler(handler)] + [tqdm_handler]