This commit is contained in:
Volodymyr Magamedov 2024-08-18 12:37:14 -05:00 committed by GitHub
commit f050de3b73
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 23 deletions

View File

@ -63,7 +63,10 @@ _H2_TO_GRPC_STATUS_MAP = {
class Handler(AbstractHandler):
connection_lost = False
closing = False
def connection_made(self, connection: Any) -> None:
pass
def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
raise NotImplementedError('Client connection can not accept requests')
@ -72,7 +75,7 @@ class Handler(AbstractHandler):
pass
def close(self) -> None:
self.connection_lost = True
self.closing = True
class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
@ -737,7 +740,7 @@ class Channel:
@property
def _connected(self) -> bool:
return (self._protocol is not None
and not self._protocol.handler.connection_lost)
and not cast(Handler, self._protocol.handler).closing)
async def __connect__(self) -> H2Protocol:
if not self._connected:

View File

@ -488,6 +488,10 @@ class Stream:
class AbstractHandler(ABC):
@abstractmethod
def connection_made(self, connection: Connection) -> None:
pass
@abstractmethod
def accept(
self,
@ -709,6 +713,7 @@ class H2Protocol(Protocol):
self.connection.flush()
self.connection.initialize()
self.handler.connection_made(self.connection)
self.processor = EventsProcessor(self.handler, self.connection)
def data_received(self, data: bytes) -> None:

View File

@ -4,6 +4,7 @@ import socket
import logging
import asyncio
import warnings
from functools import partial
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
@ -12,6 +13,7 @@ from contextlib import nullcontext
import h2.config
import h2.exceptions
from h2.errors import ErrorCodes
from multidict import MultiDict
@ -24,7 +26,7 @@ from .events import _DispatchServerEvents
from .metadata import Deadline, encode_grpc_message, _Metadata
from .metadata import encode_metadata, decode_metadata, _MetadataLike
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
from .protocol import H2Protocol, AbstractHandler
from .protocol import H2Protocol, AbstractHandler, Connection
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
@ -496,6 +498,7 @@ class _GC(abc.ABC):
class Handler(_GC, AbstractHandler):
__gc_interval__ = 10
connection: Connection
closing = False
def __init__(
@ -511,13 +514,17 @@ class Handler(_GC, AbstractHandler):
self.dispatch = dispatch
self.loop = asyncio.get_event_loop()
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
self._cancelled: Set['asyncio.Task[None]'] = set()
def __gc_collect__(self) -> None:
self._tasks = {s: t for s, t in self._tasks.items()
if not t.done()}
self._cancelled = {t for t in self._cancelled
if not t.done()}
self._tasks = {s: t for s, t in self._tasks.items() if not t.done()}
def connection_made(self, connection: Connection) -> None:
self.connection = connection
def handler_done(self, stream: 'protocol.Stream', _: Any) -> None:
self._tasks.pop(stream, None)
if not self._tasks:
self.connection.close()
def accept(
self,
@ -525,30 +532,36 @@ class Handler(_GC, AbstractHandler):
headers: _Headers,
release_stream: Callable[[], Any],
) -> None:
self.__gc_step__()
self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))
if self.closing:
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
release_stream()
else:
self.__gc_step__()
self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))
def cancel(self, stream: 'protocol.Stream') -> None:
task = self._tasks.pop(stream)
task.cancel()
self._cancelled.add(task)
self._tasks[stream].cancel()
def close(self) -> None:
for task in self._tasks.values():
self.__gc_collect__()
for stream, task in self._tasks.items():
task.add_done_callback(partial(self.handler_done, stream))
task.cancel()
self._cancelled.update(self._tasks.values())
self.closing = True
async def wait_closed(self) -> None:
if self._cancelled:
await asyncio.wait(self._cancelled)
self.__gc_collect__()
if self._tasks:
await asyncio.wait(self._tasks.values())
else:
self.connection.close()
def check_closed(self) -> bool:
self.__gc_collect__()
return not self._tasks and not self._cancelled
return not self._tasks
class Server(_GC):
@ -737,11 +750,11 @@ class Server(_GC):
if self._server is None or self._server_closed_fut is None:
raise RuntimeError('Server is not started')
await self._server_closed_fut
await self._server.wait_closed()
if self._handlers:
await asyncio.wait({
self._loop.create_task(h.wait_closed()) for h in self._handlers
})
await self._server.wait_closed()
async def __aenter__(self) -> 'Server':
return self

View File

@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
headers = None
release_stream = None
def connection_made(self, connection):
pass
def accept(self, stream, headers, release_stream):
self.stream = stream
self.headers = headers