Added more type annotations, moved protocols into dedicated private module

This commit is contained in:
Vladimir Magamedov 2019-05-31 14:50:51 +03:00
parent 7c88c771a2
commit 95ca81b4f5
8 changed files with 115 additions and 63 deletions

29
grpclib/_protocols.py Normal file
View File

@ -0,0 +1,29 @@
from typing import Mapping, Any
from typing_extensions import Protocol
from . import const
from . import server
from . import events
class IServable(Protocol):
def __mapping__(self) -> Mapping[str, const.Handler]: ...
class IClosable(Protocol):
def close(self) -> None: ...
class IProtoMessage(Protocol):
@classmethod
def FromString(cls, s: bytes) -> 'IProtoMessage': ...
def SerializeToString(self) -> bytes: ...
class IEventsTarget(Protocol):
__dispatch__: 'events._Dispatch'
class IServerMethodFunc(Protocol):
async def __call__(self, stream: 'server.Stream[Any, Any]') -> None: ...

View File

@ -1,5 +1,7 @@
import abc
from typing import Any
GRPC_CONTENT_TYPE = 'application/grpc'
@ -8,13 +10,13 @@ class CodecBase(abc.ABC):
@property
@abc.abstractmethod
def __content_subtype__(self):
def __content_subtype__(self) -> str:
pass
@abc.abstractmethod
def encode(self, message, message_type) -> bytes:
def encode(self, message: Any, message_type: Any) -> bytes:
pass
@abc.abstractmethod
def decode(self, data: bytes, message_type):
def decode(self, data: bytes, message_type: Any) -> Any:
pass

View File

@ -1,12 +1,26 @@
from typing import Type, TYPE_CHECKING
from .base import CodecBase
if TYPE_CHECKING:
from .._protocols import IProtoMessage # noqa
class ProtoCodec(CodecBase):
__content_subtype__ = 'proto'
def encode(self, message, message_type):
def encode(
self,
message: 'IProtoMessage',
message_type: Type['IProtoMessage'],
) -> bytes:
assert isinstance(message, message_type), type(message)
return message.SerializeToString()
def decode(self, data, message_type):
def decode(
self,
data: bytes,
message_type: Type['IProtoMessage'],
) -> 'IProtoMessage':
return message_type.FromString(data)

View File

@ -7,15 +7,7 @@ from .metadata import Deadline, _Metadata
if TYPE_CHECKING:
from typing_extensions import Protocol
from . import server
class _Target(Protocol):
__dispatch__: '_Dispatch'
class _MethodFunc(Protocol):
async def __call__(self, stream: 'server.Stream'): ...
from ._protocols import IEventsTarget, IServerMethodFunc # noqa
class _Event:
@ -103,7 +95,7 @@ class _DispatchMeta(type):
def listen(
target: '_Target',
target: 'IEventsTarget',
event_type: Type[_EventType],
callback: Callable[[_EventType], Coroutine],
):
@ -167,7 +159,7 @@ class RecvRequest(_Event, metaclass=_EventMeta):
__payload__ = ('metadata', 'method_func')
metadata: _Metadata
method_func: '_MethodFunc'
method_func: 'IServerMethodFunc'
method_name: str
deadline: Optional[Deadline]
content_type: str

View File

@ -3,7 +3,7 @@ import time
import platform
from base64 import b64encode, b64decode
from typing import Union, Mapping, Sequence, Tuple, NewType
from typing import Union, Mapping, Tuple, NewType, Optional, cast, Collection
from urllib.parse import quote, unquote
from multidict import MultiDict
@ -33,8 +33,10 @@ _UNITS = {
_TIMEOUT_RE = re.compile(r'^(\d+)([{}])$'.format(''.join(_UNITS)))
_Headers = Collection[Tuple[str, str]]
def decode_timeout(value):
def decode_timeout(value: str) -> float:
match = _TIMEOUT_RE.match(value)
if match is None:
raise ValueError('Invalid timeout: {}'.format(value))
@ -56,23 +58,23 @@ def encode_timeout(timeout: float) -> str:
class Deadline:
"""Represents request's deadline - fixed point in time
"""
def __init__(self, *, _timestamp):
def __init__(self, *, _timestamp: float) -> None:
self._timestamp = _timestamp
def __lt__(self, other):
def __lt__(self, other: object) -> bool:
if not isinstance(other, Deadline):
raise TypeError('comparison is not supported between '
'instances of \'{}\' and \'{}\''
.format(type(self).__name__, type(other).__name__))
return self._timestamp < other._timestamp
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Deadline):
return False
return self._timestamp == other._timestamp
@classmethod
def from_headers(cls, headers):
def from_headers(cls, headers: _Headers) -> Optional['Deadline']:
timeout = min(map(decode_timeout,
(v for k, v in headers if k == 'grpc-timeout')),
default=None)
@ -82,10 +84,10 @@ class Deadline:
return None
@classmethod
def from_timeout(cls, timeout):
def from_timeout(cls, timeout: float) -> 'Deadline':
return cls(_timestamp=time.monotonic() + timeout)
def time_remaining(self):
def time_remaining(self) -> float:
"""Calculates remaining time for the current request completion
This function returns time in seconds as a floating point number,
@ -117,11 +119,11 @@ _SPECIAL = {
_Value = Union[str, bytes]
_Metadata = NewType('_Metadata', 'MultiDict[_Value]')
_MetadataLike = Union[Mapping[str, _Value], Sequence[Tuple[str, _Value]]]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
def decode_metadata(headers) -> _Metadata:
metadata = MultiDict()
def decode_metadata(headers: _Headers) -> _Metadata:
metadata = cast(_Metadata, MultiDict())
for key, value in headers:
if key.startswith((':', 'grpc-')) or key in _SPECIAL:
continue
@ -133,8 +135,8 @@ def decode_metadata(headers) -> _Metadata:
return metadata
def encode_metadata(metadata: _MetadataLike):
if hasattr(metadata, 'items'):
def encode_metadata(metadata: _MetadataLike) -> _Headers:
if isinstance(metadata, Mapping):
metadata = metadata.items()
result = []
for key, value in metadata:

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Sequence, List
from typing import TYPE_CHECKING, Collection, List
from google.protobuf import descriptor_pool
from google.protobuf.descriptor_pb2 import FileDescriptorProto
@ -32,7 +32,7 @@ from .v1alpha.reflection_grpc import (
if TYPE_CHECKING:
from ..server import _Servable
from ..server import _Servable # noqa
class _ServerReflection:
@ -147,7 +147,7 @@ class ServerReflection(_ServerReflection, ServerReflectionBase):
Implements server reflection protocol.
"""
@classmethod
def extend(cls, services: 'Sequence[_Servable]') -> 'List[_Servable]':
def extend(cls, services: 'Collection[_Servable]') -> 'List[_Servable]':
"""
Extends services list with reflection service:

View File

@ -3,7 +3,7 @@ import socket
import logging
import asyncio
from typing import TYPE_CHECKING, Optional, Mapping, Sequence, Generic
from typing import TYPE_CHECKING, Optional, Collection, Generic
import h2.config
import h2.exceptions
@ -26,13 +26,7 @@ from .encoding.proto import ProtoCodec
if TYPE_CHECKING:
import ssl as _ssl # noqa
from typing_extensions import Protocol
from . import const
class _Servable(Protocol):
def __mapping__(self) -> Mapping[str, const.Handler]: ...
from ._protocols import IServable # noqa
log = logging.getLogger(__name__)
@ -475,7 +469,7 @@ class Server(_GC, asyncio.AbstractServer):
def __init__(
self,
handlers: Sequence['_Servable'],
handlers: Collection['IServable'],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
codec: Optional[CodecBase] = None,

View File

@ -2,7 +2,9 @@ import sys
import signal
import asyncio
from typing import Optional, Iterable, TYPE_CHECKING, Sequence
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Set, Type, ContextManager, List
from typing import Iterator, Collection
from contextlib import contextmanager
@ -13,13 +15,12 @@ else:
if TYPE_CHECKING:
from typing_extensions import Protocol
class _Closable(Protocol):
def close(self) -> None: ...
from typing import Any # noqa
from .metadata import Deadline # noqa
from ._protocols import IServable, IClosable # noqa
class Wrapper:
class Wrapper(ContextManager[None]):
"""Special wrapper for coroutines to wake them up in case of some error.
Example:
@ -36,14 +37,15 @@ class Wrapper:
w.cancel(NoNeedToWaitError('With explanation'))
"""
_error = None
_error: Optional[Exception] = None
_tasks: Set['asyncio.Task[Any]']
cancelled = None
cancelled: Optional[bool] = None
def __init__(self):
def __init__(self) -> None:
self._tasks = set()
def __enter__(self):
def __enter__(self) -> None:
if self._error is not None:
raise self._error
@ -53,14 +55,19 @@ class Wrapper:
self._tasks.add(task)
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
task = _current_task()
assert task
self._tasks.discard(task)
if self._error is not None:
raise self._error
def cancel(self, error):
def cancel(self, error: Exception) -> None:
self._error = error
for task in self._tasks:
task.cancel()
@ -88,13 +95,18 @@ class DeadlineWrapper(Wrapper):
"""
@contextmanager
def start(self, deadline, *, loop=None):
def start(
self,
deadline: 'Deadline',
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Iterator['DeadlineWrapper']:
loop = loop or asyncio.get_event_loop()
timeout = deadline.time_remaining()
if not timeout:
raise asyncio.TimeoutError('Deadline exceeded')
def callback():
def callback() -> None:
self.cancel(asyncio.TimeoutError('Deadline exceeded'))
timer = loop.call_later(timeout, callback)
@ -104,7 +116,7 @@ class DeadlineWrapper(Wrapper):
timer.cancel()
def _service_name(service):
def _service_name(service: 'IServable') -> str:
methods = service.__mapping__()
method_name = next(iter(methods), None)
assert method_name is not None
@ -112,7 +124,10 @@ def _service_name(service):
return service_name
def _first_stage(sig_num, servers):
def _first_stage(
sig_num: 'signal.Signals',
servers: Collection['IClosable'],
) -> None:
fail = False
for server in servers:
try:
@ -126,11 +141,15 @@ def _first_stage(sig_num, servers):
_second_stage(sig_num)
def _second_stage(sig_num):
def _second_stage(sig_num: 'signal.Signals') -> None:
raise SystemExit(128 + sig_num)
def _exit_handler(sig_num, servers, flag):
def _exit_handler(
sig_num: 'signal.Signals',
servers: Collection['IClosable'],
flag: List[bool],
) -> None:
if flag:
_second_stage(sig_num)
else:
@ -140,11 +159,11 @@ def _exit_handler(sig_num, servers, flag):
@contextmanager
def graceful_exit(
servers: Sequence['_Closable'],
servers: Collection['IClosable'],
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
signals: Iterable[int] = (signal.SIGINT, signal.SIGTERM),
):
signals: Collection[int] = (signal.SIGINT, signal.SIGTERM),
) -> Iterator[None]:
"""Utility context-manager to help properly shutdown server in response to
the OS signals
@ -190,7 +209,7 @@ def graceful_exit(
"""
loop = loop or asyncio.get_event_loop()
signals = set(signals)
flag = []
flag: 'List[bool]' = []
for sig_num in signals:
loop.add_signal_handler(sig_num, _exit_handler, sig_num, servers, flag)
try: