Merge 6cd97cad66
into 62f968a4c8
This commit is contained in:
commit
f050de3b73
|
@ -63,7 +63,10 @@ _H2_TO_GRPC_STATUS_MAP = {
|
|||
|
||||
|
||||
class Handler(AbstractHandler):
|
||||
connection_lost = False
|
||||
closing = False
|
||||
|
||||
def connection_made(self, connection: Any) -> None:
|
||||
pass
|
||||
|
||||
def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
|
||||
raise NotImplementedError('Client connection can not accept requests')
|
||||
|
@ -72,7 +75,7 @@ class Handler(AbstractHandler):
|
|||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
self.connection_lost = True
|
||||
self.closing = True
|
||||
|
||||
|
||||
class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
|
||||
|
@ -737,7 +740,7 @@ class Channel:
|
|||
@property
|
||||
def _connected(self) -> bool:
|
||||
return (self._protocol is not None
|
||||
and not self._protocol.handler.connection_lost)
|
||||
and not cast(Handler, self._protocol.handler).closing)
|
||||
|
||||
async def __connect__(self) -> H2Protocol:
|
||||
if not self._connected:
|
||||
|
|
|
@ -488,6 +488,10 @@ class Stream:
|
|||
|
||||
class AbstractHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def connection_made(self, connection: Connection) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def accept(
|
||||
self,
|
||||
|
@ -709,6 +713,7 @@ class H2Protocol(Protocol):
|
|||
self.connection.flush()
|
||||
self.connection.initialize()
|
||||
|
||||
self.handler.connection_made(self.connection)
|
||||
self.processor = EventsProcessor(self.handler, self.connection)
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
|
|
|
@ -4,6 +4,7 @@ import socket
|
|||
import logging
|
||||
import asyncio
|
||||
import warnings
|
||||
from functools import partial
|
||||
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
|
||||
|
@ -12,6 +13,7 @@ from contextlib import nullcontext
|
|||
|
||||
import h2.config
|
||||
import h2.exceptions
|
||||
from h2.errors import ErrorCodes
|
||||
|
||||
from multidict import MultiDict
|
||||
|
||||
|
@ -24,7 +26,7 @@ from .events import _DispatchServerEvents
|
|||
from .metadata import Deadline, encode_grpc_message, _Metadata
|
||||
from .metadata import encode_metadata, decode_metadata, _MetadataLike
|
||||
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
|
||||
from .protocol import H2Protocol, AbstractHandler
|
||||
from .protocol import H2Protocol, AbstractHandler, Connection
|
||||
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
|
||||
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
|
||||
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
|
||||
|
@ -496,6 +498,7 @@ class _GC(abc.ABC):
|
|||
class Handler(_GC, AbstractHandler):
|
||||
__gc_interval__ = 10
|
||||
|
||||
connection: Connection
|
||||
closing = False
|
||||
|
||||
def __init__(
|
||||
|
@ -511,13 +514,17 @@ class Handler(_GC, AbstractHandler):
|
|||
self.dispatch = dispatch
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
|
||||
self._cancelled: Set['asyncio.Task[None]'] = set()
|
||||
|
||||
def __gc_collect__(self) -> None:
|
||||
self._tasks = {s: t for s, t in self._tasks.items()
|
||||
if not t.done()}
|
||||
self._cancelled = {t for t in self._cancelled
|
||||
if not t.done()}
|
||||
self._tasks = {s: t for s, t in self._tasks.items() if not t.done()}
|
||||
|
||||
def connection_made(self, connection: Connection) -> None:
|
||||
self.connection = connection
|
||||
|
||||
def handler_done(self, stream: 'protocol.Stream', _: Any) -> None:
|
||||
self._tasks.pop(stream, None)
|
||||
if not self._tasks:
|
||||
self.connection.close()
|
||||
|
||||
def accept(
|
||||
self,
|
||||
|
@ -525,30 +532,36 @@ class Handler(_GC, AbstractHandler):
|
|||
headers: _Headers,
|
||||
release_stream: Callable[[], Any],
|
||||
) -> None:
|
||||
self.__gc_step__()
|
||||
self._tasks[stream] = self.loop.create_task(request_handler(
|
||||
self.mapping, stream, headers, self.codec,
|
||||
self.status_details_codec, self.dispatch, release_stream,
|
||||
))
|
||||
if self.closing:
|
||||
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
|
||||
release_stream()
|
||||
else:
|
||||
self.__gc_step__()
|
||||
self._tasks[stream] = self.loop.create_task(request_handler(
|
||||
self.mapping, stream, headers, self.codec,
|
||||
self.status_details_codec, self.dispatch, release_stream,
|
||||
))
|
||||
|
||||
def cancel(self, stream: 'protocol.Stream') -> None:
|
||||
task = self._tasks.pop(stream)
|
||||
task.cancel()
|
||||
self._cancelled.add(task)
|
||||
self._tasks[stream].cancel()
|
||||
|
||||
def close(self) -> None:
|
||||
for task in self._tasks.values():
|
||||
self.__gc_collect__()
|
||||
for stream, task in self._tasks.items():
|
||||
task.add_done_callback(partial(self.handler_done, stream))
|
||||
task.cancel()
|
||||
self._cancelled.update(self._tasks.values())
|
||||
self.closing = True
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
if self._cancelled:
|
||||
await asyncio.wait(self._cancelled)
|
||||
self.__gc_collect__()
|
||||
if self._tasks:
|
||||
await asyncio.wait(self._tasks.values())
|
||||
else:
|
||||
self.connection.close()
|
||||
|
||||
def check_closed(self) -> bool:
|
||||
self.__gc_collect__()
|
||||
return not self._tasks and not self._cancelled
|
||||
return not self._tasks
|
||||
|
||||
|
||||
class Server(_GC):
|
||||
|
@ -737,11 +750,11 @@ class Server(_GC):
|
|||
if self._server is None or self._server_closed_fut is None:
|
||||
raise RuntimeError('Server is not started')
|
||||
await self._server_closed_fut
|
||||
await self._server.wait_closed()
|
||||
if self._handlers:
|
||||
await asyncio.wait({
|
||||
self._loop.create_task(h.wait_closed()) for h in self._handlers
|
||||
})
|
||||
await self._server.wait_closed()
|
||||
|
||||
async def __aenter__(self) -> 'Server':
|
||||
return self
|
||||
|
|
|
@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
|
|||
headers = None
|
||||
release_stream = None
|
||||
|
||||
def connection_made(self, connection):
|
||||
pass
|
||||
|
||||
def accept(self, stream, headers, release_stream):
|
||||
self.stream = stream
|
||||
self.headers = headers
|
||||
|
|
Loading…
Reference in New Issue