diff --git a/libmproxy/dump.py b/libmproxy/dump.py index b02c84dc7..73ecc54d4 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -73,7 +73,8 @@ class DumpMaster(flow.FlowMaster): if options.server_replay: self.start_server_playback( self._readflow(options.server_replay), - options.kill, options.rheaders + options.kill, options.rheaders, + not options.keepserving ) if options.client_replay: diff --git a/libmproxy/flow.py b/libmproxy/flow.py index fb9bd4c72..957d53012 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -69,12 +69,12 @@ class ClientPlaybackState: class ServerPlaybackState: - def __init__(self, headers, flows): + def __init__(self, headers, flows, exit): """ headers: A case-insensitive list of request headers that should be included in request-response matching. """ - self.headers = headers + self.headers, self.exit = headers, exit self.fmap = {} for i in flows: if i.response: @@ -458,12 +458,12 @@ class FlowMaster(controller.Master): """ self.client_playback = ClientPlaybackState(flows, exit) - def start_server_playback(self, flows, kill, headers): + def start_server_playback(self, flows, kill, headers, exit): """ flows: A list of flows. kill: Boolean, should we kill requests not part of the replay? """ - self.server_playback = ServerPlaybackState(headers, flows) + self.server_playback = ServerPlaybackState(headers, flows, exit) self.kill_nonreplay = kill def do_server_playback(self, flow): @@ -492,6 +492,11 @@ class FlowMaster(controller.Master): if all(e): self.shutdown() self.client_playback.tick(self) + + if self.server_playback: + if self.server_playback.exit and self.server_playback.count() == 0: + self.shutdown() + controller.Master.tick(self, q) def handle_clientconnect(self, r): diff --git a/test/test_flow.py b/test/test_flow.py index 79e6dcc59..4090d483c 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -65,7 +65,7 @@ class uClientPlaybackState(libpry.AutoTree): class uServerPlaybackState(libpry.AutoTree): def test_hash(self): - s = flow.ServerPlaybackState(None, []) + s = flow.ServerPlaybackState(None, [], False) r = tutils.tflow() r2 = tutils.tflow() @@ -77,7 +77,7 @@ class uServerPlaybackState(libpry.AutoTree): assert s._hash(r) != s._hash(r2) def test_headers(self): - s = flow.ServerPlaybackState(["foo"], []) + s = flow.ServerPlaybackState(["foo"], [], False) r = tutils.tflow_full() r.request.headers["foo"] = ["bar"] r2 = tutils.tflow_full() @@ -98,7 +98,7 @@ class uServerPlaybackState(libpry.AutoTree): r2 = tutils.tflow_full() r2.request.headers["key"] = ["two"] - s = flow.ServerPlaybackState(None, [r, r2]) + s = flow.ServerPlaybackState(None, [r, r2], False) assert s.count() == 2 assert len(s.fmap.keys()) == 1 @@ -396,7 +396,7 @@ class uFlowMaster(libpry.AutoTree): f = tutils.tflow_full() pb = [tutils.tflow_full(), f] fm = flow.FlowMaster(None, s) - assert not fm.start_server_playback(pb, False, []) + assert not fm.start_server_playback(pb, False, [], False) assert not fm.start_client_playback(pb, False) q = Queue.Queue() @@ -417,14 +417,19 @@ class uFlowMaster(libpry.AutoTree): fm = flow.FlowMaster(None, s) assert not fm.do_server_playback(tutils.tflow()) - fm.start_server_playback(pb, False, []) + fm.start_server_playback(pb, False, [], False) assert fm.do_server_playback(tutils.tflow()) - fm.start_server_playback(pb, False, []) + fm.start_server_playback(pb, False, [], True) r = tutils.tflow() r.request.content = "gibble" assert not fm.do_server_playback(r) + assert fm.do_server_playback(tutils.tflow()) + q = Queue.Queue() + fm.tick(q) + assert fm._shutdown + def test_stickycookie(self): s = flow.State() fm = flow.FlowMaster(None, s)