Merge pull request #5604 from mhils/mypy-and-closing

Re-add `is_closing` check, fixup mypy.
This commit is contained in:
Maximilian Hils 2022-09-22 19:00:07 +02:00 committed by GitHub
commit 810401f454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 41 additions and 24 deletions

View File

@ -314,7 +314,7 @@ class Dumper:
def format_websocket_error(self, websocket: WebSocketData) -> str:
try:
ret = CloseReason(websocket.close_code).name
ret = CloseReason(websocket.close_code).name # type: ignore
except ValueError:
ret = f"UNKNOWN_ERROR={websocket.close_code}"
if websocket.close_reason:
@ -362,8 +362,8 @@ class Dumper:
desc = f"DNS {opcode} ({type})"
desc_color = {
"DNS QUERY (A)": "green",
"DNS QUERY (AAAA)": "magenta",
"A": "green",
"AAAA": "magenta",
}.get(type, "red")
desc = self.style(desc, fg=desc_color)

View File

@ -5,6 +5,7 @@ import functools
import inspect
import logging
import pyparsing
import sys
import textwrap
import types
@ -194,7 +195,7 @@ class CommandManager:
Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items.
"""
parts: list[str] = command_lexer.expr.parseString(cmdstr, parseAll=True)
parts: pyparsing.ParseResults = command_lexer.expr.parseString(cmdstr, parseAll=True)
parsed: list[ParseResult] = []
next_params: list[CommandParameter] = [

View File

@ -175,8 +175,12 @@ class DatagramWriter:
"""
self._transport = transport
self._remote_addr = remote_addr
self._reader = reader
self._closed = asyncio.Event() if reader is not None else None
if reader is not None:
self._reader = reader
self._closed = asyncio.Event()
else:
self._reader = None
self._closed = None
@property
def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport:
@ -199,9 +203,15 @@ class DatagramWriter:
self._transport.close()
else:
self._closed.set()
if self._reader is not None:
assert self._reader
self._reader.feed_eof()
def is_closing(self) -> bool:
if self._closed is None:
return self._transport.is_closing()
else:
return self._closed.is_set()
async def wait_closed(self) -> None:
if self._closed is None:
await self._protocol.wait_closed()

View File

@ -141,7 +141,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
def make_pipe(self) -> layer.CommandGenerator[None]:
self.state = self.passthrough
if self.buf:
already_received = self.buf.maybe_extract_at_most(len(self.buf))
already_received = self.buf.maybe_extract_at_most(len(self.buf)) or b""
# Some clients send superfluous newlines after CONNECT, we want to eat those.
already_received = already_received.lstrip(b"\r\n")
if already_received:
@ -264,11 +264,10 @@ class Http1Server(Http1Connection):
if isinstance(event, events.DataReceived):
request_head = self.buf.maybe_extract_lines()
if request_head:
request_head = [
bytes(x) for x in request_head
] # TODO: Make url.parse compatible with bytearrays
try:
self.request = http1.read_request_head(request_head)
self.request = http1.read_request_head(
[bytes(x) for x in request_head]
)
if self.context.options.validate_inbound_headers:
http1.validate_headers(self.request.headers)
expected_body_size = http1.expected_http_body_size(self.request)
@ -388,11 +387,10 @@ class Http1Client(Http1Connection):
response_head = self.buf.maybe_extract_lines()
if response_head:
response_head = [
bytes(x) for x in response_head
] # TODO: Make url.parse compatible with bytearrays
try:
self.response = http1.read_response_head(response_head)
self.response = http1.read_response_head(
[bytes(x) for x in response_head]
)
if self.context.options.validate_inbound_headers:
http1.validate_headers(self.response.headers)
expected_size = http1.expected_http_body_size(

View File

@ -75,11 +75,10 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
self.buf += data
response_head = self.buf.maybe_extract_lines()
if response_head:
response_head = [
bytes(x) for x in response_head
] # TODO: Make url.parse compatible with bytearrays
try:
response = http1.read_response_head(response_head)
response = http1.read_response_head([
bytes(x) for x in response_head
])
except ValueError as e:
proxyaddr = human.format_address(self.tunnel_connection.address)
yield commands.Log(f"{proxyaddr}: {e}")

View File

@ -144,10 +144,10 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
)
handler.layer = self.make_top_layer(handler.layer.context)
if isinstance(self.mode, mode_specs.TransparentMode):
socket = writer.get_extra_info("socket")
s = cast(socket.socket, writer.get_extra_info("socket"))
try:
assert platform.original_addr
original_dst = platform.original_addr(socket)
original_dst = platform.original_addr(s)
except Exception as e:
logger.error(f"Transparent mode failure: {e!r}")
return
@ -390,6 +390,7 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
await self.handle_tcp_connection(stream, stream)
def wg_handle_udp_datagram(self, data: bytes, remote_addr: Address, local_addr: Address) -> None:
assert self._server is not None
transport = WireGuardDatagramTransport(self._server, local_addr, remote_addr)
self.handle_udp_datagram(
transport,

View File

@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
elif isinstance(command, commands.SendData):
writer = self.transports[command.connection].writer
assert writer
writer.write(command.data)
if not writer.is_closing():
writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.StartHook):
@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
try:
writer = self.transports[connection].writer
assert writer
writer.write_eof()
if not writer.is_closing():
writer.write_eof()
except OSError:
# if we can't write to the socket anymore we presume it completely dead.
connection.state = ConnectionState.CLOSED

View File

@ -45,11 +45,16 @@ async def test_client_server():
server.resume_writing()
await server.drain()
assert not client_writer.is_closing()
assert not server_writer.is_closing()
assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4"
client_writer.close()
assert client_writer.is_closing()
await client_writer.wait_closed()
server_writer.close()
assert server_writer.is_closing()
await server_writer.wait_closed()
server.close()

View File

@ -36,6 +36,7 @@ deps =
types-requests==2.28.10
types-cryptography==3.3.23
types-pyOpenSSL==22.0.10
-e .[dev]
commands =
mypy {posargs}