Merge pull request #5604 from mhils/mypy-and-closing
Re-add `is_closing` check, fixup mypy.
This commit is contained in:
commit
810401f454
|
@ -314,7 +314,7 @@ class Dumper:
|
|||
|
||||
def format_websocket_error(self, websocket: WebSocketData) -> str:
|
||||
try:
|
||||
ret = CloseReason(websocket.close_code).name
|
||||
ret = CloseReason(websocket.close_code).name # type: ignore
|
||||
except ValueError:
|
||||
ret = f"UNKNOWN_ERROR={websocket.close_code}"
|
||||
if websocket.close_reason:
|
||||
|
@ -362,8 +362,8 @@ class Dumper:
|
|||
|
||||
desc = f"DNS {opcode} ({type})"
|
||||
desc_color = {
|
||||
"DNS QUERY (A)": "green",
|
||||
"DNS QUERY (AAAA)": "magenta",
|
||||
"A": "green",
|
||||
"AAAA": "magenta",
|
||||
}.get(type, "red")
|
||||
desc = self.style(desc, fg=desc_color)
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import functools
|
|||
import inspect
|
||||
import logging
|
||||
|
||||
import pyparsing
|
||||
import sys
|
||||
import textwrap
|
||||
import types
|
||||
|
@ -194,7 +195,7 @@ class CommandManager:
|
|||
Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items.
|
||||
"""
|
||||
|
||||
parts: list[str] = command_lexer.expr.parseString(cmdstr, parseAll=True)
|
||||
parts: pyparsing.ParseResults = command_lexer.expr.parseString(cmdstr, parseAll=True)
|
||||
|
||||
parsed: list[ParseResult] = []
|
||||
next_params: list[CommandParameter] = [
|
||||
|
|
|
@ -175,8 +175,12 @@ class DatagramWriter:
|
|||
"""
|
||||
self._transport = transport
|
||||
self._remote_addr = remote_addr
|
||||
self._reader = reader
|
||||
self._closed = asyncio.Event() if reader is not None else None
|
||||
if reader is not None:
|
||||
self._reader = reader
|
||||
self._closed = asyncio.Event()
|
||||
else:
|
||||
self._reader = None
|
||||
self._closed = None
|
||||
|
||||
@property
|
||||
def _protocol(self) -> DrainableDatagramProtocol | udp_wireguard.WireGuardDatagramTransport:
|
||||
|
@ -199,9 +203,15 @@ class DatagramWriter:
|
|||
self._transport.close()
|
||||
else:
|
||||
self._closed.set()
|
||||
if self._reader is not None:
|
||||
assert self._reader
|
||||
self._reader.feed_eof()
|
||||
|
||||
def is_closing(self) -> bool:
|
||||
if self._closed is None:
|
||||
return self._transport.is_closing()
|
||||
else:
|
||||
return self._closed.is_set()
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
if self._closed is None:
|
||||
await self._protocol.wait_closed()
|
||||
|
|
|
@ -141,7 +141,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
|
|||
def make_pipe(self) -> layer.CommandGenerator[None]:
|
||||
self.state = self.passthrough
|
||||
if self.buf:
|
||||
already_received = self.buf.maybe_extract_at_most(len(self.buf))
|
||||
already_received = self.buf.maybe_extract_at_most(len(self.buf)) or b""
|
||||
# Some clients send superfluous newlines after CONNECT, we want to eat those.
|
||||
already_received = already_received.lstrip(b"\r\n")
|
||||
if already_received:
|
||||
|
@ -264,11 +264,10 @@ class Http1Server(Http1Connection):
|
|||
if isinstance(event, events.DataReceived):
|
||||
request_head = self.buf.maybe_extract_lines()
|
||||
if request_head:
|
||||
request_head = [
|
||||
bytes(x) for x in request_head
|
||||
] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
self.request = http1.read_request_head(request_head)
|
||||
self.request = http1.read_request_head(
|
||||
[bytes(x) for x in request_head]
|
||||
)
|
||||
if self.context.options.validate_inbound_headers:
|
||||
http1.validate_headers(self.request.headers)
|
||||
expected_body_size = http1.expected_http_body_size(self.request)
|
||||
|
@ -388,11 +387,10 @@ class Http1Client(Http1Connection):
|
|||
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
response_head = [
|
||||
bytes(x) for x in response_head
|
||||
] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
self.response = http1.read_response_head(response_head)
|
||||
self.response = http1.read_response_head(
|
||||
[bytes(x) for x in response_head]
|
||||
)
|
||||
if self.context.options.validate_inbound_headers:
|
||||
http1.validate_headers(self.response.headers)
|
||||
expected_size = http1.expected_http_body_size(
|
||||
|
|
|
@ -75,11 +75,10 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
|
|||
self.buf += data
|
||||
response_head = self.buf.maybe_extract_lines()
|
||||
if response_head:
|
||||
response_head = [
|
||||
bytes(x) for x in response_head
|
||||
] # TODO: Make url.parse compatible with bytearrays
|
||||
try:
|
||||
response = http1.read_response_head(response_head)
|
||||
response = http1.read_response_head([
|
||||
bytes(x) for x in response_head
|
||||
])
|
||||
except ValueError as e:
|
||||
proxyaddr = human.format_address(self.tunnel_connection.address)
|
||||
yield commands.Log(f"{proxyaddr}: {e}")
|
||||
|
|
|
@ -144,10 +144,10 @@ class ServerInstance(Generic[M], metaclass=ABCMeta):
|
|||
)
|
||||
handler.layer = self.make_top_layer(handler.layer.context)
|
||||
if isinstance(self.mode, mode_specs.TransparentMode):
|
||||
socket = writer.get_extra_info("socket")
|
||||
s = cast(socket.socket, writer.get_extra_info("socket"))
|
||||
try:
|
||||
assert platform.original_addr
|
||||
original_dst = platform.original_addr(socket)
|
||||
original_dst = platform.original_addr(s)
|
||||
except Exception as e:
|
||||
logger.error(f"Transparent mode failure: {e!r}")
|
||||
return
|
||||
|
@ -390,6 +390,7 @@ class WireGuardServerInstance(ServerInstance[mode_specs.WireGuardMode]):
|
|||
await self.handle_tcp_connection(stream, stream)
|
||||
|
||||
def wg_handle_udp_datagram(self, data: bytes, remote_addr: Address, local_addr: Address) -> None:
|
||||
assert self._server is not None
|
||||
transport = WireGuardDatagramTransport(self._server, local_addr, remote_addr)
|
||||
self.handle_udp_datagram(
|
||||
transport,
|
||||
|
|
|
@ -367,7 +367,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
elif isinstance(command, commands.SendData):
|
||||
writer = self.transports[command.connection].writer
|
||||
assert writer
|
||||
writer.write(command.data)
|
||||
if not writer.is_closing():
|
||||
writer.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
self.close_connection(command.connection, command.half_close)
|
||||
elif isinstance(command, commands.StartHook):
|
||||
|
@ -393,7 +394,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||
try:
|
||||
writer = self.transports[connection].writer
|
||||
assert writer
|
||||
writer.write_eof()
|
||||
if not writer.is_closing():
|
||||
writer.write_eof()
|
||||
except OSError:
|
||||
# if we can't write to the socket anymore we presume it completely dead.
|
||||
connection.state = ConnectionState.CLOSED
|
||||
|
|
|
@ -45,11 +45,16 @@ async def test_client_server():
|
|||
server.resume_writing()
|
||||
await server.drain()
|
||||
|
||||
assert not client_writer.is_closing()
|
||||
assert not server_writer.is_closing()
|
||||
|
||||
assert await client_reader.read(MAX_DATAGRAM_SIZE) == b"msg4"
|
||||
client_writer.close()
|
||||
assert client_writer.is_closing()
|
||||
await client_writer.wait_closed()
|
||||
|
||||
server_writer.close()
|
||||
assert server_writer.is_closing()
|
||||
await server_writer.wait_closed()
|
||||
|
||||
server.close()
|
||||
|
|
Loading…
Reference in New Issue