fix bugs, fix tests
This commit is contained in:
parent
63844df343
commit
a7058e2a3c
|
@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap):
|
|||
r.append("[%s]" % (":".join(opts)))
|
||||
|
||||
if self.master.server.config.mode in ["reverse", "upstream"]:
|
||||
dst = self.master.server.config.mode.dst
|
||||
scheme = "https" if dst[0] else "http"
|
||||
if dst[1] != dst[0]:
|
||||
scheme += "2https" if dst[1] else "http"
|
||||
r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:]))
|
||||
dst = self.master.server.config.upstream_server
|
||||
r.append("[dest:%s]" % netlib.utils.unparse_url(
|
||||
dst.scheme,
|
||||
dst.address.host,
|
||||
dst.address.port
|
||||
))
|
||||
if self.master.scripts:
|
||||
r.append("[")
|
||||
r.append(("heading_key", "s"))
|
||||
|
|
|
@ -40,6 +40,7 @@ class _HttpLayer(Layer):
|
|||
def send_response(self, response):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class _StreamingHttpLayer(_HttpLayer):
|
||||
supports_streaming = True
|
||||
|
||||
|
@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer):
|
|||
|
||||
|
||||
class Http1Layer(_StreamingHttpLayer):
|
||||
|
||||
def __init__(self, ctx, mode):
|
||||
super(Http1Layer, self).__init__(ctx)
|
||||
self.mode = mode
|
||||
|
@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer):
|
|||
|
||||
def send_response_headers(self, response):
|
||||
h = self.client_protocol._assemble_response_first_line(response)
|
||||
self.client_conn.wfile.write(h+"\r\n")
|
||||
self.client_conn.wfile.write(h + "\r\n")
|
||||
h = self.client_protocol._assemble_response_headers(
|
||||
response,
|
||||
preserve_transfer_encoding=True
|
||||
)
|
||||
self.client_conn.send(h+"\r\n")
|
||||
self.client_conn.send(h + "\r\n")
|
||||
|
||||
def send_response_body(self, response, chunks):
|
||||
if self.client_protocol.has_chunked_encoding(response.headers):
|
||||
|
@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer):
|
|||
def __init__(self, ctx, mode):
|
||||
super(Http2Layer, self).__init__(ctx)
|
||||
self.mode = mode
|
||||
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True,
|
||||
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
|
||||
def read_request(self):
|
||||
request = HTTPRequest.from_protocol(
|
||||
|
@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer):
|
|||
|
||||
def connect(self):
|
||||
self.ctx.connect()
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol.perform_connection_preface()
|
||||
|
||||
def reconnect(self):
|
||||
self.ctx.reconnect()
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol.perform_connection_preface()
|
||||
|
||||
def set_server(self, *args, **kwargs):
|
||||
self.ctx.set_server(*args, **kwargs)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||
self.server_protocol.perform_connection_preface()
|
||||
|
||||
def __call__(self):
|
||||
|
@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer):
|
|||
def __init__(self, ctx, connect_request):
|
||||
super(UpstreamConnectLayer, self).__init__(ctx)
|
||||
self.connect_request = connect_request
|
||||
self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx)
|
||||
self.server_conn = ConnectServerConnection(
|
||||
(connect_request.host, connect_request.port),
|
||||
self.ctx
|
||||
)
|
||||
|
||||
def __call__(self):
|
||||
layer = self.ctx.next_layer(self)
|
||||
|
@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer):
|
|||
def reconnect(self):
|
||||
self.ctx.reconnect()
|
||||
self.send_request(self.connect_request)
|
||||
resp = self.read_response("CONNECT")
|
||||
if resp.code != 200:
|
||||
raise ProtocolException("Reconnect: Upstream server refuses CONNECT request")
|
||||
|
||||
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
||||
if depth == 1:
|
||||
|
@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer):
|
|||
self.connect_request.port = address.port
|
||||
self.server_conn.address = address
|
||||
else:
|
||||
self.ctx.set_server(address, server_tls, sni, depth-1)
|
||||
self.ctx.set_server(address, server_tls, sni, depth - 1)
|
||||
|
||||
|
||||
class HttpLayer(Layer):
|
||||
|
@ -413,10 +424,10 @@ class HttpLayer(Layer):
|
|||
# First send the headers and then transfer the response incrementally
|
||||
self.send_response_headers(flow.response)
|
||||
chunks = self.read_response_body(
|
||||
flow.response.headers,
|
||||
flow.request.method,
|
||||
flow.response.code,
|
||||
max_chunk_size=4096
|
||||
flow.response.headers,
|
||||
flow.request.method,
|
||||
flow.response.code,
|
||||
max_chunk_size=4096
|
||||
)
|
||||
if callable(flow.response.stream):
|
||||
chunks = flow.response.stream(chunks)
|
||||
|
@ -521,7 +532,8 @@ class HttpLayer(Layer):
|
|||
# If there's not TlsLayer below which could catch the exception,
|
||||
# TLS will not be established.
|
||||
if tls and not self.server_conn.tls_established:
|
||||
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
|
||||
raise ProtocolException(
|
||||
"Cannot upgrade to SSL, no TLS layer on the protocol stack.")
|
||||
else:
|
||||
if not self.server_conn:
|
||||
self.connect()
|
||||
|
@ -542,7 +554,8 @@ class HttpLayer(Layer):
|
|||
|
||||
def validate_request(self, request):
|
||||
if request.form_in == "absolute" and request.scheme != "http":
|
||||
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||
self.send_response(
|
||||
make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
||||
|
||||
expected_request_forms = {
|
||||
|
@ -570,7 +583,11 @@ class HttpLayer(Layer):
|
|||
self.send_response(make_error_response(
|
||||
407,
|
||||
"Proxy Authentication Required",
|
||||
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
|
||||
odict.ODictCaseless(
|
||||
[
|
||||
[k, v] for k, v in
|
||||
self.config.authenticator.auth_challenge_headers().items()
|
||||
])
|
||||
))
|
||||
raise InvalidCredentials("Proxy Authentication Required")
|
||||
|
||||
|
@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread):
|
|||
if r.scheme == "https":
|
||||
connect_request = make_connect_request((r.host, r.port))
|
||||
server.send(protocol.assemble(connect_request))
|
||||
resp = protocol.read_response("CONNECT")
|
||||
if resp.code != 200:
|
||||
raise HttpError(502, "Upstream server refuses CONNECT request")
|
||||
server.establish_ssl(
|
||||
self.config.clientcerts,
|
||||
sni=self.flow.server_conn.sni
|
||||
|
|
|
@ -1,9 +1,8 @@
|
|||
import argparse
|
||||
from libmproxy import cmdline
|
||||
from libmproxy.proxy import ProxyConfig, process_proxy_options
|
||||
from libmproxy.proxy.connection import ServerConnection
|
||||
from libmproxy.proxy.primitives import ProxyError
|
||||
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler
|
||||
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2
|
||||
import tutils
|
||||
from libpathod import test
|
||||
from netlib import http, tcp
|
||||
|
@ -175,8 +174,10 @@ class TestDummyServer:
|
|||
class TestConnectionHandler:
|
||||
def test_fatal_error(self):
|
||||
config = mock.Mock()
|
||||
config.mode.get_upstream_server.side_effect = RuntimeError
|
||||
c = ConnectionHandler(
|
||||
root_layer = mock.Mock()
|
||||
root_layer.side_effect = RuntimeError
|
||||
config.mode.return_value = root_layer
|
||||
c = ConnectionHandler2(
|
||||
config,
|
||||
mock.MagicMock(),
|
||||
("127.0.0.1",
|
||||
|
|
|
@ -68,7 +68,7 @@ class CommonMixin:
|
|||
# SSL with the upstream proxy.
|
||||
rt = self.master.replay_request(l, block=True)
|
||||
assert not rt
|
||||
if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl:
|
||||
if isinstance(self, tservers.HTTPUpstreamProxTest):
|
||||
assert l.response.code == 502
|
||||
else:
|
||||
assert l.error
|
||||
|
@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin):
|
|||
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
|
||||
p.connect()
|
||||
r = p.request("get:/")
|
||||
assert r.status_code == 400
|
||||
assert r.status_code == 502
|
||||
|
||||
|
||||
class TestProxy(tservers.HTTPProxTest):
|
||||
|
@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest):
|
|||
assert resp.headers["Transfer-Encoding"][0] == 'chunked'
|
||||
assert resp.status_code == 200
|
||||
|
||||
chunks = list(
|
||||
content for _, content, _ in protocol.read_http_body_chunked(
|
||||
resp.headers, None, "GET", 200, False))
|
||||
chunks = list(protocol.read_http_body_chunked(
|
||||
resp.headers, None, "GET", 200, False
|
||||
))
|
||||
assert chunks == ["this", "isatest", ""]
|
||||
|
||||
connection.close()
|
||||
|
@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
|
|||
|
||||
p = self.pathoc()
|
||||
req = p.request("get:'/p/418:b\"content\"'")
|
||||
assert req.content == "content"
|
||||
assert req.status_code == 418
|
||||
|
||||
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
|
||||
# CONNECT, failing request,
|
||||
assert self.chain[0].tmaster.state.flow_count() == 4
|
||||
|
@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
|
|||
assert self.chain[1].tmaster.state.flow_count() == 2
|
||||
# (doesn't store (repeated) CONNECTs from chain[0]
|
||||
# as it is a regular proxy)
|
||||
assert req.content == "content"
|
||||
assert req.status_code == 418
|
||||
|
||||
|
||||
assert not self.chain[1].tmaster.state.flows[0].response # killed
|
||||
assert self.chain[1].tmaster.state.flows[1].response
|
||||
|
|
|
@ -181,22 +181,24 @@ class TResolver:
|
|||
def original_addr(self, sock):
|
||||
return ("127.0.0.1", self.port)
|
||||
|
||||
|
||||
class TransparentProxTest(ProxTestBase):
|
||||
ssl = None
|
||||
resolver = TResolver
|
||||
|
||||
@classmethod
|
||||
@mock.patch("libmproxy.platform.resolver")
|
||||
def setupAll(cls, _):
|
||||
def setupAll(cls):
|
||||
super(TransparentProxTest, cls).setupAll()
|
||||
if cls.ssl:
|
||||
ports = [cls.server.port, cls.server2.port]
|
||||
else:
|
||||
ports = []
|
||||
cls.config.mode = TransparentProxyMode(
|
||||
cls.resolver(cls.server.port),
|
||||
ports)
|
||||
|
||||
cls._resolver = mock.patch(
|
||||
"libmproxy.platform.resolver",
|
||||
new=lambda: cls.resolver(cls.server.port)
|
||||
)
|
||||
cls._resolver.start()
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
cls._resolver.stop()
|
||||
super(TransparentProxTest, cls).teardownAll()
|
||||
|
||||
@classmethod
|
||||
def get_proxy_config(cls):
|
||||
|
@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest):
|
|||
d["mode"] = "socks5"
|
||||
return d
|
||||
|
||||
class SpoofModeTest(ProxTestBase):
|
||||
ssl = None
|
||||
|
||||
@classmethod
|
||||
def get_proxy_config(cls):
|
||||
d = ProxTestBase.get_proxy_config()
|
||||
d["upstream_server"] = None
|
||||
d["mode"] = "spoof"
|
||||
return d
|
||||
|
||||
def pathoc(self, sni=None):
|
||||
"""
|
||||
Returns a connected Pathoc instance.
|
||||
"""
|
||||
p = libpathod.pathoc.Pathoc(
|
||||
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
|
||||
)
|
||||
p.connect()
|
||||
return p
|
||||
|
||||
|
||||
class SSLSpoofModeTest(ProxTestBase):
|
||||
ssl = True
|
||||
|
||||
@classmethod
|
||||
def get_proxy_config(cls):
|
||||
d = ProxTestBase.get_proxy_config()
|
||||
d["upstream_server"] = None
|
||||
d["mode"] = "sslspoof"
|
||||
d["spoofed_ssl_port"] = 443
|
||||
return d
|
||||
|
||||
def pathoc(self, sni=None):
|
||||
"""
|
||||
Returns a connected Pathoc instance.
|
||||
"""
|
||||
p = libpathod.pathoc.Pathoc(
|
||||
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
|
||||
)
|
||||
p.connect()
|
||||
return p
|
||||
|
||||
|
||||
class ChainProxTest(ProxTestBase):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue