Added more type annotations, moved protocols into dedicated private module
This commit is contained in:
parent
7c88c771a2
commit
95ca81b4f5
|
@ -0,0 +1,29 @@
|
|||
from typing import Mapping, Any
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import const
|
||||
from . import server
|
||||
from . import events
|
||||
|
||||
|
||||
class IServable(Protocol):
|
||||
def __mapping__(self) -> Mapping[str, const.Handler]: ...
|
||||
|
||||
|
||||
class IClosable(Protocol):
|
||||
def close(self) -> None: ...
|
||||
|
||||
|
||||
class IProtoMessage(Protocol):
|
||||
@classmethod
|
||||
def FromString(cls, s: bytes) -> 'IProtoMessage': ...
|
||||
|
||||
def SerializeToString(self) -> bytes: ...
|
||||
|
||||
|
||||
class IEventsTarget(Protocol):
|
||||
__dispatch__: 'events._Dispatch'
|
||||
|
||||
|
||||
class IServerMethodFunc(Protocol):
|
||||
async def __call__(self, stream: 'server.Stream[Any, Any]') -> None: ...
|
|
@ -1,5 +1,7 @@
|
|||
import abc
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
GRPC_CONTENT_TYPE = 'application/grpc'
|
||||
|
||||
|
@ -8,13 +10,13 @@ class CodecBase(abc.ABC):
|
|||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def __content_subtype__(self):
|
||||
def __content_subtype__(self) -> str:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def encode(self, message, message_type) -> bytes:
|
||||
def encode(self, message: Any, message_type: Any) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def decode(self, data: bytes, message_type):
|
||||
def decode(self, data: bytes, message_type: Any) -> Any:
|
||||
pass
|
||||
|
|
|
@ -1,12 +1,26 @@
|
|||
from typing import Type, TYPE_CHECKING
|
||||
|
||||
from .base import CodecBase
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._protocols import IProtoMessage # noqa
|
||||
|
||||
|
||||
class ProtoCodec(CodecBase):
|
||||
__content_subtype__ = 'proto'
|
||||
|
||||
def encode(self, message, message_type):
|
||||
def encode(
|
||||
self,
|
||||
message: 'IProtoMessage',
|
||||
message_type: Type['IProtoMessage'],
|
||||
) -> bytes:
|
||||
assert isinstance(message, message_type), type(message)
|
||||
return message.SerializeToString()
|
||||
|
||||
def decode(self, data, message_type):
|
||||
def decode(
|
||||
self,
|
||||
data: bytes,
|
||||
message_type: Type['IProtoMessage'],
|
||||
) -> 'IProtoMessage':
|
||||
return message_type.FromString(data)
|
||||
|
|
|
@ -7,15 +7,7 @@ from .metadata import Deadline, _Metadata
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import server
|
||||
|
||||
class _Target(Protocol):
|
||||
__dispatch__: '_Dispatch'
|
||||
|
||||
class _MethodFunc(Protocol):
|
||||
async def __call__(self, stream: 'server.Stream'): ...
|
||||
from ._protocols import IEventsTarget, IServerMethodFunc # noqa
|
||||
|
||||
|
||||
class _Event:
|
||||
|
@ -103,7 +95,7 @@ class _DispatchMeta(type):
|
|||
|
||||
|
||||
def listen(
|
||||
target: '_Target',
|
||||
target: 'IEventsTarget',
|
||||
event_type: Type[_EventType],
|
||||
callback: Callable[[_EventType], Coroutine],
|
||||
):
|
||||
|
@ -167,7 +159,7 @@ class RecvRequest(_Event, metaclass=_EventMeta):
|
|||
__payload__ = ('metadata', 'method_func')
|
||||
|
||||
metadata: _Metadata
|
||||
method_func: '_MethodFunc'
|
||||
method_func: 'IServerMethodFunc'
|
||||
method_name: str
|
||||
deadline: Optional[Deadline]
|
||||
content_type: str
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
import platform
|
||||
|
||||
from base64 import b64encode, b64decode
|
||||
from typing import Union, Mapping, Sequence, Tuple, NewType
|
||||
from typing import Union, Mapping, Tuple, NewType, Optional, cast, Collection
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from multidict import MultiDict
|
||||
|
@ -33,8 +33,10 @@ _UNITS = {
|
|||
|
||||
_TIMEOUT_RE = re.compile(r'^(\d+)([{}])$'.format(''.join(_UNITS)))
|
||||
|
||||
_Headers = Collection[Tuple[str, str]]
|
||||
|
||||
def decode_timeout(value):
|
||||
|
||||
def decode_timeout(value: str) -> float:
|
||||
match = _TIMEOUT_RE.match(value)
|
||||
if match is None:
|
||||
raise ValueError('Invalid timeout: {}'.format(value))
|
||||
|
@ -56,23 +58,23 @@ def encode_timeout(timeout: float) -> str:
|
|||
class Deadline:
|
||||
"""Represents request's deadline - fixed point in time
|
||||
"""
|
||||
def __init__(self, *, _timestamp):
|
||||
def __init__(self, *, _timestamp: float) -> None:
|
||||
self._timestamp = _timestamp
|
||||
|
||||
def __lt__(self, other):
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, Deadline):
|
||||
raise TypeError('comparison is not supported between '
|
||||
'instances of \'{}\' and \'{}\''
|
||||
.format(type(self).__name__, type(other).__name__))
|
||||
return self._timestamp < other._timestamp
|
||||
|
||||
def __eq__(self, other):
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Deadline):
|
||||
return False
|
||||
return self._timestamp == other._timestamp
|
||||
|
||||
@classmethod
|
||||
def from_headers(cls, headers):
|
||||
def from_headers(cls, headers: _Headers) -> Optional['Deadline']:
|
||||
timeout = min(map(decode_timeout,
|
||||
(v for k, v in headers if k == 'grpc-timeout')),
|
||||
default=None)
|
||||
|
@ -82,10 +84,10 @@ class Deadline:
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def from_timeout(cls, timeout):
|
||||
def from_timeout(cls, timeout: float) -> 'Deadline':
|
||||
return cls(_timestamp=time.monotonic() + timeout)
|
||||
|
||||
def time_remaining(self):
|
||||
def time_remaining(self) -> float:
|
||||
"""Calculates remaining time for the current request completion
|
||||
|
||||
This function returns time in seconds as a floating point number,
|
||||
|
@ -117,11 +119,11 @@ _SPECIAL = {
|
|||
|
||||
_Value = Union[str, bytes]
|
||||
_Metadata = NewType('_Metadata', 'MultiDict[_Value]')
|
||||
_MetadataLike = Union[Mapping[str, _Value], Sequence[Tuple[str, _Value]]]
|
||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
||||
|
||||
|
||||
def decode_metadata(headers) -> _Metadata:
|
||||
metadata = MultiDict()
|
||||
def decode_metadata(headers: _Headers) -> _Metadata:
|
||||
metadata = cast(_Metadata, MultiDict())
|
||||
for key, value in headers:
|
||||
if key.startswith((':', 'grpc-')) or key in _SPECIAL:
|
||||
continue
|
||||
|
@ -133,8 +135,8 @@ def decode_metadata(headers) -> _Metadata:
|
|||
return metadata
|
||||
|
||||
|
||||
def encode_metadata(metadata: _MetadataLike):
|
||||
if hasattr(metadata, 'items'):
|
||||
def encode_metadata(metadata: _MetadataLike) -> _Headers:
|
||||
if isinstance(metadata, Mapping):
|
||||
metadata = metadata.items()
|
||||
result = []
|
||||
for key, value in metadata:
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import TYPE_CHECKING, Sequence, List
|
||||
from typing import TYPE_CHECKING, Collection, List
|
||||
|
||||
from google.protobuf import descriptor_pool
|
||||
from google.protobuf.descriptor_pb2 import FileDescriptorProto
|
||||
|
@ -32,7 +32,7 @@ from .v1alpha.reflection_grpc import (
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..server import _Servable
|
||||
from ..server import _Servable # noqa
|
||||
|
||||
|
||||
class _ServerReflection:
|
||||
|
@ -147,7 +147,7 @@ class ServerReflection(_ServerReflection, ServerReflectionBase):
|
|||
Implements server reflection protocol.
|
||||
"""
|
||||
@classmethod
|
||||
def extend(cls, services: 'Sequence[_Servable]') -> 'List[_Servable]':
|
||||
def extend(cls, services: 'Collection[_Servable]') -> 'List[_Servable]':
|
||||
"""
|
||||
Extends services list with reflection service:
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import socket
|
|||
import logging
|
||||
import asyncio
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Mapping, Sequence, Generic
|
||||
from typing import TYPE_CHECKING, Optional, Collection, Generic
|
||||
|
||||
import h2.config
|
||||
import h2.exceptions
|
||||
|
@ -26,13 +26,7 @@ from .encoding.proto import ProtoCodec
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import ssl as _ssl # noqa
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from . import const
|
||||
|
||||
class _Servable(Protocol):
|
||||
def __mapping__(self) -> Mapping[str, const.Handler]: ...
|
||||
from ._protocols import IServable # noqa
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -475,7 +469,7 @@ class Server(_GC, asyncio.AbstractServer):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
handlers: Sequence['_Servable'],
|
||||
handlers: Collection['IServable'],
|
||||
*,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
codec: Optional[CodecBase] = None,
|
||||
|
|
|
@ -2,7 +2,9 @@ import sys
|
|||
import signal
|
||||
import asyncio
|
||||
|
||||
from typing import Optional, Iterable, TYPE_CHECKING, Sequence
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Optional, Set, Type, ContextManager, List
|
||||
from typing import Iterator, Collection
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
|
@ -13,13 +15,12 @@ else:
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
class _Closable(Protocol):
|
||||
def close(self) -> None: ...
|
||||
from typing import Any # noqa
|
||||
from .metadata import Deadline # noqa
|
||||
from ._protocols import IServable, IClosable # noqa
|
||||
|
||||
|
||||
class Wrapper:
|
||||
class Wrapper(ContextManager[None]):
|
||||
"""Special wrapper for coroutines to wake them up in case of some error.
|
||||
|
||||
Example:
|
||||
|
@ -36,14 +37,15 @@ class Wrapper:
|
|||
w.cancel(NoNeedToWaitError('With explanation'))
|
||||
|
||||
"""
|
||||
_error = None
|
||||
_error: Optional[Exception] = None
|
||||
_tasks: Set['asyncio.Task[Any]']
|
||||
|
||||
cancelled = None
|
||||
cancelled: Optional[bool] = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._tasks = set()
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
|
||||
|
@ -53,14 +55,19 @@ class Wrapper:
|
|||
|
||||
self._tasks.add(task)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
task = _current_task()
|
||||
assert task
|
||||
self._tasks.discard(task)
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
|
||||
def cancel(self, error):
|
||||
def cancel(self, error: Exception) -> None:
|
||||
self._error = error
|
||||
for task in self._tasks:
|
||||
task.cancel()
|
||||
|
@ -88,13 +95,18 @@ class DeadlineWrapper(Wrapper):
|
|||
|
||||
"""
|
||||
@contextmanager
|
||||
def start(self, deadline, *, loop=None):
|
||||
def start(
|
||||
self,
|
||||
deadline: 'Deadline',
|
||||
*,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
) -> Iterator['DeadlineWrapper']:
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
timeout = deadline.time_remaining()
|
||||
if not timeout:
|
||||
raise asyncio.TimeoutError('Deadline exceeded')
|
||||
|
||||
def callback():
|
||||
def callback() -> None:
|
||||
self.cancel(asyncio.TimeoutError('Deadline exceeded'))
|
||||
|
||||
timer = loop.call_later(timeout, callback)
|
||||
|
@ -104,7 +116,7 @@ class DeadlineWrapper(Wrapper):
|
|||
timer.cancel()
|
||||
|
||||
|
||||
def _service_name(service):
|
||||
def _service_name(service: 'IServable') -> str:
|
||||
methods = service.__mapping__()
|
||||
method_name = next(iter(methods), None)
|
||||
assert method_name is not None
|
||||
|
@ -112,7 +124,10 @@ def _service_name(service):
|
|||
return service_name
|
||||
|
||||
|
||||
def _first_stage(sig_num, servers):
|
||||
def _first_stage(
|
||||
sig_num: 'signal.Signals',
|
||||
servers: Collection['IClosable'],
|
||||
) -> None:
|
||||
fail = False
|
||||
for server in servers:
|
||||
try:
|
||||
|
@ -126,11 +141,15 @@ def _first_stage(sig_num, servers):
|
|||
_second_stage(sig_num)
|
||||
|
||||
|
||||
def _second_stage(sig_num):
|
||||
def _second_stage(sig_num: 'signal.Signals') -> None:
|
||||
raise SystemExit(128 + sig_num)
|
||||
|
||||
|
||||
def _exit_handler(sig_num, servers, flag):
|
||||
def _exit_handler(
|
||||
sig_num: 'signal.Signals',
|
||||
servers: Collection['IClosable'],
|
||||
flag: List[bool],
|
||||
) -> None:
|
||||
if flag:
|
||||
_second_stage(sig_num)
|
||||
else:
|
||||
|
@ -140,11 +159,11 @@ def _exit_handler(sig_num, servers, flag):
|
|||
|
||||
@contextmanager
|
||||
def graceful_exit(
|
||||
servers: Sequence['_Closable'],
|
||||
servers: Collection['IClosable'],
|
||||
*,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
signals: Iterable[int] = (signal.SIGINT, signal.SIGTERM),
|
||||
):
|
||||
signals: Collection[int] = (signal.SIGINT, signal.SIGTERM),
|
||||
) -> Iterator[None]:
|
||||
"""Utility context-manager to help properly shutdown server in response to
|
||||
the OS signals
|
||||
|
||||
|
@ -190,7 +209,7 @@ def graceful_exit(
|
|||
"""
|
||||
loop = loop or asyncio.get_event_loop()
|
||||
signals = set(signals)
|
||||
flag = []
|
||||
flag: 'List[bool]' = []
|
||||
for sig_num in signals:
|
||||
loop.add_signal_handler(sig_num, _exit_handler, sig_num, servers, flag)
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue