A new interface for reply

Reply is now explicit - it's no longer a callable itself. Instead, we have:

    reply.kill()            - kill the flow
    reply.ack()             - ack, but don't send anything
    reply.send(message)     - send a response

This is part of an incremental move to detach reply from our flow objects,
and unify the script and handler interfaces.
This commit is contained in:
Aldo Cortesi 2016-06-08 10:44:20 +12:00
parent 982077ec31
commit a388ddfd78
6 changed files with 28 additions and 40 deletions

View File

@ -16,7 +16,7 @@ def request(context, flow):
"HTTP/1.1", 200, "OK",
Headers(Content_Type="text/html"),
"helloworld")
flow.reply(resp)
flow.reply.send(resp)
# Method 2: Redirect the request to a different server
if flow.request.pretty_host.endswith("example.org"):

View File

@ -134,5 +134,5 @@ def next_layer(context, next_layer):
# We don't intercept - reply with a pass-through layer and add a "skipped" entry.
context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info")
next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False)
next_layer.reply(next_layer_replacement)
next_layer.reply.send(next_layer_replacement)
context.tls_strategy.record_skipped(server_address)

View File

@ -145,10 +145,6 @@ class Channel(object):
self.q.put((mtype, m))
# Special value to distinguish the case where no reply was sent
NO_REPLY = object()
def handler(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
@ -199,22 +195,19 @@ class Reply(object):
self.handled = False
def ack(self):
self(NO_REPLY)
self.send(self.obj)
def kill(self):
self(exceptions.Kill)
self.send(exceptions.Kill)
def take(self):
self.taken = True
def __call__(self, msg=NO_REPLY):
def send(self, msg):
if self.acked:
raise exceptions.ControlException("Message already acked.")
self.acked = True
if msg is NO_REPLY:
self.q.put(self.obj)
else:
self.q.put(msg)
self.q.put(msg)
def __del__(self):
if not self.acked:
@ -233,13 +226,13 @@ class DummyReply(object):
self.handled = False
def kill(self):
self()
self.send(None)
def ack(self):
self()
self.send(None)
def take(self):
self.taken = True
def __call__(self, msg=False):
def send(self, msg):
self.acked = True

View File

@ -4,6 +4,7 @@ offload computations from mitmproxy's main master thread.
"""
from __future__ import absolute_import, print_function, division
from mitmproxy import controller
import threading
@ -14,15 +15,15 @@ class ReplyProxy(object):
self.script_thread = script_thread
self.master_reply = None
def __call__(self, *args):
def send(self, message):
if self.master_reply is None:
self.master_reply = args
self.master_reply = message
self.script_thread.start()
return
self.reply_func(*args)
self.reply_func(message)
def done(self):
self.reply_func(*self.master_reply)
self.reply_func.send(self.master_reply)
def __getattr__(self, k):
return getattr(self.reply_func, k)
@ -49,17 +50,11 @@ class ScriptThread(threading.Thread):
def concurrent(fn):
if fn.__name__ in (
"request",
"response",
"error",
"clientconnect",
"serverconnect",
"clientdisconnect",
"next_layer"):
def _concurrent(ctx, obj):
_handle_concurrent_reply(fn, obj, ctx, obj)
if fn.__name__ not in controller.Events:
raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__
)
return _concurrent
raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__)
def _concurrent(ctx, obj):
_handle_concurrent_reply(fn, obj, ctx, obj)
return _concurrent

View File

@ -66,7 +66,7 @@ class TestChannel(object):
def reply():
m, obj = q.get()
assert m == "test"
obj.reply(42)
obj.reply.send(42)
Thread(target=reply).start()
@ -86,7 +86,7 @@ class TestDummyReply(object):
def test_simple(self):
reply = controller.DummyReply()
assert not reply.acked
reply()
reply.ack()
assert reply.acked
@ -94,16 +94,16 @@ class TestReply(object):
def test_simple(self):
reply = controller.Reply(42)
assert not reply.acked
reply("foo")
reply.send("foo")
assert reply.acked
assert reply.q.get() == "foo"
def test_default(self):
reply = controller.Reply(42)
reply()
reply.ack()
assert reply.q.get() == 42
def test_reply_none(self):
reply = controller.Reply(42)
reply(None)
reply.send(None)
assert reply.q.get() is None

View File

@ -743,7 +743,7 @@ class MasterFakeResponse(tservers.TestMaster):
@controller.handler
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
f.reply(resp)
f.reply.send(resp)
class TestFakeResponse(tservers.HTTPProxyTest):
@ -819,7 +819,7 @@ class MasterIncomplete(tservers.TestMaster):
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None
f.reply(resp)
f.reply.send(resp)
class TestIncompleteResponse(tservers.HTTPProxyTest):