diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index a316602c0..092990b10 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -1008,7 +1008,7 @@ class ConsoleMaster(flow.FlowMaster): self.statusbar.refresh_flow(c) def process_flow(self, f, r): - if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay(): + if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: r.reply() diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index a94f7ae4c..715bed801 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -172,7 +172,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): intercepting = f.intercepting, req_timestamp = f.request.timestamp_start, - req_is_replay = f.request.is_replay(), + req_is_replay = f.request.is_replay, req_method = f.request.method, req_acked = f.request.reply.acked, req_url = f.request.get_url(hostheader=hostheader), @@ -194,7 +194,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): d.update(dict( resp_code = f.response.code, - resp_is_replay = f.response.is_replay(), + resp_is_replay = f.response.is_replay, resp_acked = f.response.reply.acked, resp_clen = contentdesc, resp_rate = "{0}/s".format(rate), diff --git a/libmproxy/console/flowdetailview.py b/libmproxy/console/flowdetailview.py index a26e53083..8392537ee 100644 --- a/libmproxy/console/flowdetailview.py +++ b/libmproxy/console/flowdetailview.py @@ -74,9 +74,9 @@ class FlowDetailsView(urwid.ListBox): ) text.extend(common.format_keyvals(parts, key="key", val="text", indent=4)) - if self.flow.request.client_conn: + if self.flow.client_conn: text.append(urwid.Text([("head", "Client Connection:")])) - cc = self.flow.request.client_conn + cc = self.flow.client_conn parts = [ ["Address", "%s:%s"%tuple(cc.address)], ["Requests", "%s"%cc.requestcount], diff --git a/libmproxy/flow.py b/libmproxy/flow.py index b19714695..bf9171a70 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -174,7 +174,7 @@ class ClientPlaybackState: if self.flows and not self.current: n = self.flows.pop(0) n.request.reply = controller.DummyReply() - n.request.client_conn = None + n.client_conn = None self.current = master.handle_request(n.request) if not testing and not self.current.response: master.replay_request(self.current) # pragma: no cover @@ -249,9 +249,10 @@ class StickyCookieState: """ Returns a (domain, port, path) tuple. """ + raise NotImplementedError return ( m["domain"] or f.request.host, - f.request.port, + f.server_conn.address.port, m["path"] or "/" ) @@ -297,6 +298,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): + raise NotImplementedError if "authorization" in f.request.headers: self.hosts[f.request.host] = f.request.headers["authorization"] elif f.match(self.flt): @@ -665,11 +667,10 @@ class FlowMaster(controller.Master): return f def handle_request(self, r): - if False and r.is_live(): # FIXME + if r.flow.client_conn and r.flow.client_conn.wfile: app = self.apps.get(r) if app: - # FIXME: for the tcp proxy, use flow.client_conn.wfile - err = app.serve(r, r.wfile, **{"mitmproxy.master": self}) + err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") r.reply(KILL) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index be60f3746..636e1b071 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -56,7 +56,7 @@ class decoded(object): class HTTPMessage(stateobject.SimpleStateObject): def __init__(self): - self.flow = None # Will usually set by backref mixin + self.flow = None # will usually be set by the flow backref mixin """@type: HTTPFlow""" def get_decoded_content(self): @@ -197,7 +197,7 @@ class HTTPRequest(HTTPMessage): timestamp_end: Timestamp indicating when request transmission ended """ def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, - timestamp_start, timestamp_end, form_out=None): + timestamp_start=None, timestamp_end=None, form_out=None): assert isinstance(headers, ODictCaseless) or not headers HTTPMessage.__init__(self) @@ -758,7 +758,6 @@ class HTTPFlow(Flow): """ Continue with the flow - called after an intercept(). """ - assert self.intercepting if self.request: if not self.request.reply.acked: self.request.reply() diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index f77e097b6..f3fdd245a 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -41,6 +41,7 @@ class Error(stateobject.SimpleStateObject): @type msg: str @type timestamp: float """ + self.flow = None # will usually be set by the flow backref mixin self.msg = msg self.timestamp = timestamp or utils.timestamp() @@ -88,6 +89,9 @@ class Flow(stateobject.SimpleStateObject, _BackreferenceMixin): f._load_state(state) return f + def __eq__(self, other): + return self is other + def copy(self): f = copy.copy(self) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index feff2259f..53e3f575b 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -55,6 +55,10 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): if client_connection: # Eventually, this object is restored from state. We don't have a connection then. tcp.BaseHandler.__init__(self, client_connection, address, server) else: + self.connection = None + self.server = None + self.wfile = None + self.rfile = None self.address = None self.clientcert = None @@ -86,7 +90,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): @classmethod def _from_state(cls, state): - f = cls(None, None, None) + f = cls(None, tuple(), None) f._load_state(state) return f @@ -141,7 +145,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): @classmethod def _from_state(cls, state): - f = cls(None) + f = cls(tuple()) f._load_state(state) return f diff --git a/test/test_flow.py b/test/test_flow.py index 3e111fc15..5ae6c8d6a 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -2,7 +2,8 @@ import Queue, time, os.path from cStringIO import StringIO import email.utils from libmproxy import filt, protocol, controller, utils, tnetstring, proxy, flow -from libmproxy.protocol.primitives import Error +from libmproxy.protocol.primitives import Error, Flow +from libmproxy.protocol.http import decoded import tutils @@ -176,7 +177,8 @@ class TestFlow: f2 = f.copy() a = f._get_state() b = f2._get_state() - assert f == f2 + assert f._get_state() == f2._get_state() + assert not f == f2 assert not f is f2 assert f.request == f2.request assert not f.request is f2.request @@ -234,7 +236,8 @@ class TestFlow: assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() f2 = f.copy() - assert f == f2 + assert f._get_state() == f2._get_state() + assert not f == f2 f2.error = Error("e2") assert not f == f2 f._load_state(f2._get_state()) @@ -340,7 +343,7 @@ class TestState: connect -> request -> response """ - bc = flow.ClientConnect(("address", 22)) + bc = tutils.tclient_conn() c = flow.State() req = tutils.treq(bc) @@ -359,6 +362,7 @@ class TestState: assert c.active_flow_count() == 1 unseen_resp = tutils.tresp() + unseen_resp.flow = None assert not c.add_response(unseen_resp) assert c.active_flow_count() == 1 @@ -370,8 +374,8 @@ class TestState: c = flow.State() req = tutils.treq() f = c.add_request(req) - e = Error("message") - assert c.add_error(e) + f.error = Error("message") + assert c.add_error(f.error) e = Error("message") assert not c.add_error(e) @@ -379,10 +383,9 @@ class TestState: c = flow.State() req = tutils.treq() f = c.add_request(req) - e = Error("message") + e = tutils.terr() c.set_limit("~e") assert not c.view - assert not c.view assert c.add_error(e) assert c.view @@ -460,7 +463,7 @@ class TestState: c.clear() c.load_flows(flows) - assert isinstance(c._flow_list[0], flow.Flow) + assert isinstance(c._flow_list[0], Flow) def test_accept_all(self): c = flow.State() @@ -595,9 +598,7 @@ class TestFlowMaster: #load second script assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - dc = flow.ClientDisconnect(req.flow.client_conn) - dc.reply = controller.DummyReply() - fm.handle_clientdisconnect(dc) + fm.handle_clientdisconnect(sc) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" @@ -607,7 +608,7 @@ class TestFlowMaster: assert len(fm.scripts) == 0 assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = Error("msg") + err = tutils.terr() err.reply = controller.DummyReply() fm.handle_error(err) assert fm.scripts[0].ns["log"][-1] == "error" @@ -621,7 +622,7 @@ class TestFlowMaster: f2 = fm.duplicate_flow(f) assert f2.response assert s.flow_count() == 2 - assert s.index(f2) + assert s.index(f2) == 1 def test_all(self): s = flow.State() @@ -639,15 +640,14 @@ class TestFlowMaster: assert s.flow_count() == 1 rx = tutils.tresp() + rx.flow = None assert not fm.handle_response(rx) - dc = flow.ClientDisconnect(req.flow.client_conn) - dc.reply = controller.DummyReply() - fm.handle_clientdisconnect(dc) + fm.handle_clientdisconnect(req.flow.client_conn) - err = Error("msg") - err.reply = controller.DummyReply() - fm.handle_error(err) + f.error = Error("msg") + f.error.reply = controller.DummyReply() + fm.handle_error(f.error) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -666,9 +666,9 @@ class TestFlowMaster: fm.tick(q) assert fm.state.flow_count() - err = Error("error") - err.reply = controller.DummyReply() - fm.handle_error(err) + f.error = Error("error") + f.error.reply = controller.DummyReply() + fm.handle_error(f.error) def test_server_playback(self): controller.should_exit = False @@ -771,20 +771,16 @@ class TestFlowMaster: assert r()[0].response - tf = tutils.tflow_full() + tf = tutils.tflow() fm.start_stream(file(p, "ab"), None) fm.handle_request(tf.request) fm.shutdown() assert not r()[1].response - class TestRequest: def test_simple(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() u = r.get_url() assert r.set_url(u) assert not r.set_url("") @@ -799,19 +795,11 @@ class TestRequest: assert r._assemble() assert r.size() == len(r._assemble()) - r.close = True - assert "connection: close" in r._assemble() - - assert r._assemble(True) - r.content = flow.CONTENT_MISSING - assert not r._assemble() + tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.tflow().request assert r.get_url() == "https://host:22/" assert r.get_url(hostheader=True) == "https://host:22/" r.headers["Host"] = ["foo.com"] @@ -819,11 +807,10 @@ class TestRequest: assert r.get_url(hostheader=True) == "https://foo.com:22/" def test_path_components(self): - h = flow.ODictCaseless() - c = flow.ClientConnect(("addr", 2222)) - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r = tutils.treq() + r.path = "/" assert r.get_path_components() == [] - r = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/foo/bar", h, "content") + r.path = "/foo/bar" assert r.get_path_components() == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] @@ -839,10 +826,9 @@ class TestRequest: assert "%2F" in r.path def test_getset_form_urlencoded(self): - h = flow.ODictCaseless() - h["content-type"] = [flow.HDR_FORM_URLENCODED] d = flow.ODict([("one", "two"), ("three", "four")]) - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, utils.urlencode(d.lst)) + r = tutils.treq(content=utils.urlencode(d.lst)) + r.headers["content-type"] = [protocol.http.HDR_FORM_URLENCODED] assert r.get_form_urlencoded() == d d = flow.ODict([("x", "y")]) @@ -855,19 +841,20 @@ class TestRequest: def test_getset_query(self): h = flow.ODictCaseless() - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/foo?x=y&a=b", h, "content") + r = tutils.treq() + r.path = "/foo?x=y&a=b" q = r.get_query() assert q.lst == [("x", "y"), ("a", "b")] - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/", h, "content") + r.path = "/" q = r.get_query() assert not q - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/?adsfa", h, "content") + r.path = "/?adsfa" q = r.get_query() assert q.lst == [("adsfa", "")] - r = flow.Request(None, (1, 1), "host", 22, "https", "GET", "/foo?x=y&a=b", h, "content") + r.path = "/foo?x=y&a=b" assert r.get_query() r.set_query(flow.ODict([])) assert not r.get_query() @@ -985,34 +972,20 @@ class TestRequest: print r._assemble_headers() assert result == 62 - def test_get_transmitted_size(self): - h = flow.ODictCaseless() - h["headername"] = ["headervalue"] - r = tutils.treq() - r.headers = h - result = r.get_transmitted_size() - assert result==len("content") - r.content = None - assert r.get_transmitted_size() == 0 - def test_get_content_type(self): h = flow.ODictCaseless() h["Content-Type"] = ["text/plain"] resp = tutils.tresp() resp.headers = h - assert resp.get_content_type()=="text/plain" + assert resp.headers.get_first("content-type") == "text/plain" class TestResponse: def test_simple(self): - h = flow.ODictCaseless() - h["test"] = ["test"] - c = flow.ClientConnect(("addr", 2222)) - req = flow.Request(c, (1, 1), "host", 22, "https", "GET", "/", h, "content") - resp = flow.Response(req, (1, 1), 200, "msg", h.copy(), "content", None) + f = tutils.tflow_full() + resp = f.response assert resp._assemble() assert resp.size() == len(resp._assemble()) - resp2 = resp.copy() assert resp2 == resp @@ -1021,7 +994,7 @@ class TestResponse: assert resp.size() == len(resp._assemble()) resp.content = flow.CONTENT_MISSING - assert not resp._assemble() + tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp._assemble) def test_refresh(self): r = tutils.tresp() @@ -1147,7 +1120,7 @@ class TestResponse: h["Content-Type"] = ["text/plain"] resp = tutils.tresp() resp.headers = h - assert resp.headers.get_first("content-type")=="text/plain" + assert resp.headers.get_first("content-type") == "text/plain" class TestError: @@ -1195,13 +1168,13 @@ def test_decoded(): r.encode("gzip") assert r.headers["content-encoding"] assert r.content != "content" - with flow.decoded(r): + with decoded(r): assert not r.headers["content-encoding"] assert r.content == "content" assert r.headers["content-encoding"] assert r.content != "content" - with flow.decoded(r): + with decoded(r): r.content = "foo" assert r.content != "foo" diff --git a/test/test_proxy.py b/test/test_proxy.py index 737e4a923..41d41d0cc 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -22,14 +22,12 @@ class TestServerConnection: sc = proxy.ServerConnection((self.d.IFACE, self.d.port)) sc.connect() r = tutils.treq() + r.flow.server_conn = sc r.path = "/p/200:da" sc.send(r._assemble()) assert http.read_response(sc.rfile, r.method, 1000) assert self.d.last_log() - r.content = flow.CONTENT_MISSING - tutils.raises("incomplete request", sc.send, r._assemble()) - sc.finish() def test_terminate_error(self): diff --git a/test/test_script.py b/test/test_script.py index 025e9f377..7ee85f2c5 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -32,8 +32,8 @@ class TestScript: r = tutils.treq() fm.handle_request(r) assert fm.state.flow_count() == 2 - assert not fm.state.view[0].request.is_replay() - assert fm.state.view[1].request.is_replay() + assert not fm.state.view[0].request.is_replay + assert fm.state.view[1].request.is_replay def test_err(self): s = flow.State() diff --git a/test/test_server.py b/test/test_server.py index f542062d9..756492934 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -372,7 +372,6 @@ class TestTransparentResolveError(tservers.TransparentProxTest): class MasterIncomplete(tservers.TestMaster): def handle_request(self, m): - # FIXME: fails because of a ._assemble().splitlines() log statement. resp = tutils.tresp() resp.content = flow.CONTENT_MISSING m.reply(resp) diff --git a/test/tutils.py b/test/tutils.py index 0d3b94f42..10cd0eb93 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -5,6 +5,7 @@ from libmproxy.protocol import http if os.name != "nt": from libmproxy.console.flowview import FlowView from libmproxy.console import ConsoleState +from libmproxy.protocol.primitives import Error from netlib import certutils from nose.plugins.skip import SkipTest from mock import Mock @@ -27,6 +28,7 @@ def tclient_conn(): c.reply = controller.DummyReply() return c + def tserver_conn(): c = proxy.ServerConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), @@ -34,6 +36,7 @@ def tserver_conn(): cert=None )) c.reply = controller.DummyReply() + return c def treq(conn=None, content="content"): @@ -71,7 +74,7 @@ def terr(req=None): if not req: req = treq() f = req.flow - f.error = flow.Error("error") + f.error = Error("error") f.error.reply = controller.DummyReply() return f.error