From 2d22aa6882544db6427dccd3e733e93e01366b81 Mon Sep 17 00:00:00 2001 From: Nathan Page Date: Sun, 9 May 2021 15:11:42 -0700 Subject: [PATCH] update exit to have proper typings --- rich/console.py | 33 +++++++++++++++++++++++++++------ rich/live.py | 10 ++++++++-- rich/progress.py | 25 ++++++++++++++++--------- rich/status.py | 10 ++++++++-- 4 files changed, 59 insertions(+), 19 deletions(-) diff --git a/rich/console.py b/rich/console.py index ec1093e9..94752e10 100644 --- a/rich/console.py +++ b/rich/console.py @@ -12,6 +12,7 @@ from functools import wraps from getpass import getpass from itertools import islice from time import monotonic +from types import TracebackType from typing import ( IO, TYPE_CHECKING, @@ -24,6 +25,7 @@ from typing import ( NamedTuple, Optional, TextIO, + Type, Union, cast, ) @@ -33,7 +35,6 @@ try: except ImportError: # pragma: no cover from typing import Literal, Protocol, runtime_checkable # type: ignore - from . import errors, themes from ._emoji_replace import _emoji_replace from ._log_render import FormatTimeCallable, LogRender @@ -44,7 +45,7 @@ from .highlighter import NullHighlighter, ReprHighlighter from .markup import render as render_markup from .measure import Measurement, measure_renderables from .pager import Pager, SystemPager -from .pretty import is_expandable, Pretty +from .pretty import Pretty, is_expandable from .region import Region from .scope import render_scope from .screen import Screen @@ -294,7 +295,12 @@ class Capture: self._console.begin_capture() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self._result = self._console.end_capture() def get(self) -> str: @@ -318,7 +324,12 @@ class ThemeContext: self.console.push_theme(self.theme) return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.console.pop_theme() @@ -341,7 +352,12 @@ class PagerContext: self._console._enter_buffer() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: if exc_type is None: with self._console._lock: buffer: List[Segment] = self._console._buffer[:] @@ -391,7 +407,12 @@ class ScreenContext: self.console.show_cursor(False) return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: if self._changed: self.console.set_alt_screen(False) if self.hide_cursor: diff --git a/rich/live.py b/rich/live.py index e7a0f9f5..f2ba64a7 100644 --- a/rich/live.py +++ b/rich/live.py @@ -1,6 +1,7 @@ import sys from threading import Event, RLock, Thread -from typing import IO, Any, Callable, List, Optional, TextIO, cast +from types import TracebackType +from typing import IO, Any, Callable, List, Optional, TextIO, Type, cast from . import get_console from .console import Console, ConsoleRenderable, RenderableType, RenderHook @@ -163,7 +164,12 @@ class Live(JupyterMixin, RenderHook): self.start(refresh=self._renderable is not None) return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.stop() def _enable_redirect_io(self) -> None: diff --git a/rich/progress.py b/rich/progress.py index f4484824..e67173d3 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from datetime import timedelta from math import ceil from threading import Event, RLock, Thread +from types import TracebackType from typing import ( Any, Callable, @@ -18,19 +19,15 @@ from typing import ( Optional, Sequence, Tuple, + Type, TypeVar, Union, ) from . import filesize, get_console -from .console import ( - Console, - JustifyMethod, - RenderableType, - RenderGroup, -) -from .jupyter import JupyterMixin +from .console import Console, JustifyMethod, RenderableType, RenderGroup from .highlighter import Highlighter +from .jupyter import JupyterMixin from .live import Live from .progress_bar import ProgressBar from .spinner import Spinner @@ -75,7 +72,12 @@ class _TrackThread(Thread): self.start() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.done.set() self.join() @@ -649,7 +651,12 @@ class Progress(JupyterMixin): self.start() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.stop() def track( diff --git a/rich/status.py b/rich/status.py index 92c273f9..09eff405 100644 --- a/rich/status.py +++ b/rich/status.py @@ -1,4 +1,5 @@ -from typing import Any, Optional +from types import TracebackType +from typing import Optional, Type from .console import Console, RenderableType from .jupyter import JupyterMixin @@ -96,7 +97,12 @@ class Status(JupyterMixin): self.start() return self - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.stop()