update exit to have proper typings

This commit is contained in:
Nathan Page 2021-05-09 15:11:42 -07:00
parent 037727b5a1
commit 2d22aa6882
4 changed files with 59 additions and 19 deletions

View File

@ -12,6 +12,7 @@ from functools import wraps
from getpass import getpass from getpass import getpass
from itertools import islice from itertools import islice
from time import monotonic from time import monotonic
from types import TracebackType
from typing import ( from typing import (
IO, IO,
TYPE_CHECKING, TYPE_CHECKING,
@ -24,6 +25,7 @@ from typing import (
NamedTuple, NamedTuple,
Optional, Optional,
TextIO, TextIO,
Type,
Union, Union,
cast, cast,
) )
@ -33,7 +35,6 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from typing import Literal, Protocol, runtime_checkable # type: ignore from typing import Literal, Protocol, runtime_checkable # type: ignore
from . import errors, themes from . import errors, themes
from ._emoji_replace import _emoji_replace from ._emoji_replace import _emoji_replace
from ._log_render import FormatTimeCallable, LogRender from ._log_render import FormatTimeCallable, LogRender
@ -44,7 +45,7 @@ from .highlighter import NullHighlighter, ReprHighlighter
from .markup import render as render_markup from .markup import render as render_markup
from .measure import Measurement, measure_renderables from .measure import Measurement, measure_renderables
from .pager import Pager, SystemPager from .pager import Pager, SystemPager
from .pretty import is_expandable, Pretty from .pretty import Pretty, is_expandable
from .region import Region from .region import Region
from .scope import render_scope from .scope import render_scope
from .screen import Screen from .screen import Screen
@ -294,7 +295,12 @@ class Capture:
self._console.begin_capture() self._console.begin_capture()
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self._result = self._console.end_capture() self._result = self._console.end_capture()
def get(self) -> str: def get(self) -> str:
@ -318,7 +324,12 @@ class ThemeContext:
self.console.push_theme(self.theme) self.console.push_theme(self.theme)
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.console.pop_theme() self.console.pop_theme()
@ -341,7 +352,12 @@ class PagerContext:
self._console._enter_buffer() self._console._enter_buffer()
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc_type is None: if exc_type is None:
with self._console._lock: with self._console._lock:
buffer: List[Segment] = self._console._buffer[:] buffer: List[Segment] = self._console._buffer[:]
@ -391,7 +407,12 @@ class ScreenContext:
self.console.show_cursor(False) self.console.show_cursor(False)
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if self._changed: if self._changed:
self.console.set_alt_screen(False) self.console.set_alt_screen(False)
if self.hide_cursor: if self.hide_cursor:

View File

@ -1,6 +1,7 @@
import sys import sys
from threading import Event, RLock, Thread from threading import Event, RLock, Thread
from typing import IO, Any, Callable, List, Optional, TextIO, cast from types import TracebackType
from typing import IO, Any, Callable, List, Optional, TextIO, Type, cast
from . import get_console from . import get_console
from .console import Console, ConsoleRenderable, RenderableType, RenderHook from .console import Console, ConsoleRenderable, RenderableType, RenderHook
@ -163,7 +164,12 @@ class Live(JupyterMixin, RenderHook):
self.start(refresh=self._renderable is not None) self.start(refresh=self._renderable is not None)
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.stop() self.stop()
def _enable_redirect_io(self) -> None: def _enable_redirect_io(self) -> None:

View File

@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from math import ceil from math import ceil
from threading import Event, RLock, Thread from threading import Event, RLock, Thread
from types import TracebackType
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -18,19 +19,15 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
) )
from . import filesize, get_console from . import filesize, get_console
from .console import ( from .console import Console, JustifyMethod, RenderableType, RenderGroup
Console,
JustifyMethod,
RenderableType,
RenderGroup,
)
from .jupyter import JupyterMixin
from .highlighter import Highlighter from .highlighter import Highlighter
from .jupyter import JupyterMixin
from .live import Live from .live import Live
from .progress_bar import ProgressBar from .progress_bar import ProgressBar
from .spinner import Spinner from .spinner import Spinner
@ -75,7 +72,12 @@ class _TrackThread(Thread):
self.start() self.start()
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.done.set() self.done.set()
self.join() self.join()
@ -649,7 +651,12 @@ class Progress(JupyterMixin):
self.start() self.start()
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.stop() self.stop()
def track( def track(

View File

@ -1,4 +1,5 @@
from typing import Any, Optional from types import TracebackType
from typing import Optional, Type
from .console import Console, RenderableType from .console import Console, RenderableType
from .jupyter import JupyterMixin from .jupyter import JupyterMixin
@ -96,7 +97,12 @@ class Status(JupyterMixin):
self.start() self.start()
return self return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.stop() self.stop()