diff --git a/grpclib/_protocols.py b/grpclib/_protocols.py new file mode 100644 index 0000000..8da8a39 --- /dev/null +++ b/grpclib/_protocols.py @@ -0,0 +1,29 @@ +from typing import Mapping, Any +from typing_extensions import Protocol + +from . import const +from . import server +from . import events + + +class IServable(Protocol): + def __mapping__(self) -> Mapping[str, const.Handler]: ... + + +class IClosable(Protocol): + def close(self) -> None: ... + + +class IProtoMessage(Protocol): + @classmethod + def FromString(cls, s: bytes) -> 'IProtoMessage': ... + + def SerializeToString(self) -> bytes: ... + + +class IEventsTarget(Protocol): + __dispatch__: 'events._Dispatch' + + +class IServerMethodFunc(Protocol): + async def __call__(self, stream: 'server.Stream[Any, Any]') -> None: ... diff --git a/grpclib/encoding/base.py b/grpclib/encoding/base.py index 5acc22b..6e5b48d 100644 --- a/grpclib/encoding/base.py +++ b/grpclib/encoding/base.py @@ -1,5 +1,7 @@ import abc +from typing import Any + GRPC_CONTENT_TYPE = 'application/grpc' @@ -8,13 +10,13 @@ class CodecBase(abc.ABC): @property @abc.abstractmethod - def __content_subtype__(self): + def __content_subtype__(self) -> str: pass @abc.abstractmethod - def encode(self, message, message_type) -> bytes: + def encode(self, message: Any, message_type: Any) -> bytes: pass @abc.abstractmethod - def decode(self, data: bytes, message_type): + def decode(self, data: bytes, message_type: Any) -> Any: pass diff --git a/grpclib/encoding/proto.py b/grpclib/encoding/proto.py index 511b597..4c33b53 100644 --- a/grpclib/encoding/proto.py +++ b/grpclib/encoding/proto.py @@ -1,12 +1,26 @@ +from typing import Type, TYPE_CHECKING + from .base import CodecBase +if TYPE_CHECKING: + from .._protocols import IProtoMessage # noqa + + class ProtoCodec(CodecBase): __content_subtype__ = 'proto' - def encode(self, message, message_type): + def encode( + self, + message: 'IProtoMessage', + message_type: Type['IProtoMessage'], + ) -> bytes: assert isinstance(message, message_type), type(message) return message.SerializeToString() - def decode(self, data, message_type): + def decode( + self, + data: bytes, + message_type: Type['IProtoMessage'], + ) -> 'IProtoMessage': return message_type.FromString(data) diff --git a/grpclib/events.py b/grpclib/events.py index 316e585..30b8882 100644 --- a/grpclib/events.py +++ b/grpclib/events.py @@ -7,15 +7,7 @@ from .metadata import Deadline, _Metadata if TYPE_CHECKING: - from typing_extensions import Protocol - - from . import server - - class _Target(Protocol): - __dispatch__: '_Dispatch' - - class _MethodFunc(Protocol): - async def __call__(self, stream: 'server.Stream'): ... + from ._protocols import IEventsTarget, IServerMethodFunc # noqa class _Event: @@ -103,7 +95,7 @@ class _DispatchMeta(type): def listen( - target: '_Target', + target: 'IEventsTarget', event_type: Type[_EventType], callback: Callable[[_EventType], Coroutine], ): @@ -167,7 +159,7 @@ class RecvRequest(_Event, metaclass=_EventMeta): __payload__ = ('metadata', 'method_func') metadata: _Metadata - method_func: '_MethodFunc' + method_func: 'IServerMethodFunc' method_name: str deadline: Optional[Deadline] content_type: str diff --git a/grpclib/metadata.py b/grpclib/metadata.py index 398aa5a..abfad91 100644 --- a/grpclib/metadata.py +++ b/grpclib/metadata.py @@ -3,7 +3,7 @@ import time import platform from base64 import b64encode, b64decode -from typing import Union, Mapping, Sequence, Tuple, NewType +from typing import Union, Mapping, Tuple, NewType, Optional, cast, Collection from urllib.parse import quote, unquote from multidict import MultiDict @@ -33,8 +33,10 @@ _UNITS = { _TIMEOUT_RE = re.compile(r'^(\d+)([{}])$'.format(''.join(_UNITS))) +_Headers = Collection[Tuple[str, str]] -def decode_timeout(value): + +def decode_timeout(value: str) -> float: match = _TIMEOUT_RE.match(value) if match is None: raise ValueError('Invalid timeout: {}'.format(value)) @@ -56,23 +58,23 @@ def encode_timeout(timeout: float) -> str: class Deadline: """Represents request's deadline - fixed point in time """ - def __init__(self, *, _timestamp): + def __init__(self, *, _timestamp: float) -> None: self._timestamp = _timestamp - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if not isinstance(other, Deadline): raise TypeError('comparison is not supported between ' 'instances of \'{}\' and \'{}\'' .format(type(self).__name__, type(other).__name__)) return self._timestamp < other._timestamp - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Deadline): return False return self._timestamp == other._timestamp @classmethod - def from_headers(cls, headers): + def from_headers(cls, headers: _Headers) -> Optional['Deadline']: timeout = min(map(decode_timeout, (v for k, v in headers if k == 'grpc-timeout')), default=None) @@ -82,10 +84,10 @@ class Deadline: return None @classmethod - def from_timeout(cls, timeout): + def from_timeout(cls, timeout: float) -> 'Deadline': return cls(_timestamp=time.monotonic() + timeout) - def time_remaining(self): + def time_remaining(self) -> float: """Calculates remaining time for the current request completion This function returns time in seconds as a floating point number, @@ -117,11 +119,11 @@ _SPECIAL = { _Value = Union[str, bytes] _Metadata = NewType('_Metadata', 'MultiDict[_Value]') -_MetadataLike = Union[Mapping[str, _Value], Sequence[Tuple[str, _Value]]] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] -def decode_metadata(headers) -> _Metadata: - metadata = MultiDict() +def decode_metadata(headers: _Headers) -> _Metadata: + metadata = cast(_Metadata, MultiDict()) for key, value in headers: if key.startswith((':', 'grpc-')) or key in _SPECIAL: continue @@ -133,8 +135,8 @@ def decode_metadata(headers) -> _Metadata: return metadata -def encode_metadata(metadata: _MetadataLike): - if hasattr(metadata, 'items'): +def encode_metadata(metadata: _MetadataLike) -> _Headers: + if isinstance(metadata, Mapping): metadata = metadata.items() result = [] for key, value in metadata: diff --git a/grpclib/reflection/service.py b/grpclib/reflection/service.py index f87b9b5..354f1b7 100644 --- a/grpclib/reflection/service.py +++ b/grpclib/reflection/service.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import TYPE_CHECKING, Sequence, List +from typing import TYPE_CHECKING, Collection, List from google.protobuf import descriptor_pool from google.protobuf.descriptor_pb2 import FileDescriptorProto @@ -32,7 +32,7 @@ from .v1alpha.reflection_grpc import ( if TYPE_CHECKING: - from ..server import _Servable + from ..server import _Servable # noqa class _ServerReflection: @@ -147,7 +147,7 @@ class ServerReflection(_ServerReflection, ServerReflectionBase): Implements server reflection protocol. """ @classmethod - def extend(cls, services: 'Sequence[_Servable]') -> 'List[_Servable]': + def extend(cls, services: 'Collection[_Servable]') -> 'List[_Servable]': """ Extends services list with reflection service: diff --git a/grpclib/server.py b/grpclib/server.py index 69c1068..aae28bd 100644 --- a/grpclib/server.py +++ b/grpclib/server.py @@ -3,7 +3,7 @@ import socket import logging import asyncio -from typing import TYPE_CHECKING, Optional, Mapping, Sequence, Generic +from typing import TYPE_CHECKING, Optional, Collection, Generic import h2.config import h2.exceptions @@ -26,13 +26,7 @@ from .encoding.proto import ProtoCodec if TYPE_CHECKING: import ssl as _ssl # noqa - - from typing_extensions import Protocol - - from . import const - - class _Servable(Protocol): - def __mapping__(self) -> Mapping[str, const.Handler]: ... + from ._protocols import IServable # noqa log = logging.getLogger(__name__) @@ -475,7 +469,7 @@ class Server(_GC, asyncio.AbstractServer): def __init__( self, - handlers: Sequence['_Servable'], + handlers: Collection['IServable'], *, loop: Optional[asyncio.AbstractEventLoop] = None, codec: Optional[CodecBase] = None, diff --git a/grpclib/utils.py b/grpclib/utils.py index 814261f..f92de96 100644 --- a/grpclib/utils.py +++ b/grpclib/utils.py @@ -2,7 +2,9 @@ import sys import signal import asyncio -from typing import Optional, Iterable, TYPE_CHECKING, Sequence +from types import TracebackType +from typing import TYPE_CHECKING, Optional, Set, Type, ContextManager, List +from typing import Iterator, Collection from contextlib import contextmanager @@ -13,13 +15,12 @@ else: if TYPE_CHECKING: - from typing_extensions import Protocol - - class _Closable(Protocol): - def close(self) -> None: ... + from typing import Any # noqa + from .metadata import Deadline # noqa + from ._protocols import IServable, IClosable # noqa -class Wrapper: +class Wrapper(ContextManager[None]): """Special wrapper for coroutines to wake them up in case of some error. Example: @@ -36,14 +37,15 @@ class Wrapper: w.cancel(NoNeedToWaitError('With explanation')) """ - _error = None + _error: Optional[Exception] = None + _tasks: Set['asyncio.Task[Any]'] - cancelled = None + cancelled: Optional[bool] = None - def __init__(self): + def __init__(self) -> None: self._tasks = set() - def __enter__(self): + def __enter__(self) -> None: if self._error is not None: raise self._error @@ -53,14 +55,19 @@ class Wrapper: self._tasks.add(task) - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: task = _current_task() assert task self._tasks.discard(task) if self._error is not None: raise self._error - def cancel(self, error): + def cancel(self, error: Exception) -> None: self._error = error for task in self._tasks: task.cancel() @@ -88,13 +95,18 @@ class DeadlineWrapper(Wrapper): """ @contextmanager - def start(self, deadline, *, loop=None): + def start( + self, + deadline: 'Deadline', + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> Iterator['DeadlineWrapper']: loop = loop or asyncio.get_event_loop() timeout = deadline.time_remaining() if not timeout: raise asyncio.TimeoutError('Deadline exceeded') - def callback(): + def callback() -> None: self.cancel(asyncio.TimeoutError('Deadline exceeded')) timer = loop.call_later(timeout, callback) @@ -104,7 +116,7 @@ class DeadlineWrapper(Wrapper): timer.cancel() -def _service_name(service): +def _service_name(service: 'IServable') -> str: methods = service.__mapping__() method_name = next(iter(methods), None) assert method_name is not None @@ -112,7 +124,10 @@ def _service_name(service): return service_name -def _first_stage(sig_num, servers): +def _first_stage( + sig_num: 'signal.Signals', + servers: Collection['IClosable'], +) -> None: fail = False for server in servers: try: @@ -126,11 +141,15 @@ def _first_stage(sig_num, servers): _second_stage(sig_num) -def _second_stage(sig_num): +def _second_stage(sig_num: 'signal.Signals') -> None: raise SystemExit(128 + sig_num) -def _exit_handler(sig_num, servers, flag): +def _exit_handler( + sig_num: 'signal.Signals', + servers: Collection['IClosable'], + flag: List[bool], +) -> None: if flag: _second_stage(sig_num) else: @@ -140,11 +159,11 @@ def _exit_handler(sig_num, servers, flag): @contextmanager def graceful_exit( - servers: Sequence['_Closable'], + servers: Collection['IClosable'], *, loop: Optional[asyncio.AbstractEventLoop] = None, - signals: Iterable[int] = (signal.SIGINT, signal.SIGTERM), -): + signals: Collection[int] = (signal.SIGINT, signal.SIGTERM), +) -> Iterator[None]: """Utility context-manager to help properly shutdown server in response to the OS signals @@ -190,7 +209,7 @@ def graceful_exit( """ loop = loop or asyncio.get_event_loop() signals = set(signals) - flag = [] + flag: 'List[bool]' = [] for sig_num in signals: loop.add_signal_handler(sig_num, _exit_handler, sig_num, servers, flag) try: