From 1089a52f3d16c4fef504586cae18a5d324e8d75c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 21 Jun 2012 10:56:30 +1200 Subject: [PATCH] Disconnect, rest refactoring. --- libpathod/pathod.py | 36 ++++++++++++++++++------------------ libpathod/rparse.py | 13 +++++++------ libpathod/utils.py | 25 ------------------------- test/test_pathod.py | 10 +++++----- test/test_rparse.py | 18 +++++++++--------- 5 files changed, 39 insertions(+), 63 deletions(-) diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 8a29b9cb3..e0a0764fd 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -18,6 +18,10 @@ class PathodHandler(tcp.BaseHandler): return None method, path, httpversion = protocol.parse_init_http(line) + headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) + content = protocol.read_http_body_request( + self.rfile, self.wfile, headers, httpversion, None + ) if path.startswith(self.server.prefix): spec = urllib.unquote(path)[len(self.server.prefix):] try: @@ -27,24 +31,20 @@ class PathodHandler(tcp.BaseHandler): 800, "Error parsing response spec: %s\n"%v.msg + v.marked() ) - presp.serve(self.wfile) - self.finish() - return - - headers = odict.ODictCaseless(protocol.read_headers(self.rfile)) - content = protocol.read_http_body_request( - self.rfile, self.wfile, headers, httpversion, None - ) - cc = wsgi.ClientConn(self.client_address) - req = wsgi.Request(cc, "http", method, path, headers, content) - sn = self.connection.getsockname() - app = wsgi.WSGIAdaptor( - self.server.app, - sn[0], - self.server.port, - version.NAMEVERSION - ) - app.serve(req, self.wfile) + ret = presp.serve(self.wfile) + if ret["disconnect"]: + self.close() + else: + cc = wsgi.ClientConn(self.client_address) + req = wsgi.Request(cc, "http", method, path, headers, content) + sn = self.connection.getsockname() + app = wsgi.WSGIAdaptor( + self.server.app, + sn[0], + self.server.port, + version.NAMEVERSION + ) + app.serve(req, self.wfile) class Pathod(tcp.TCPServer): diff --git a/libpathod/rparse.py b/libpathod/rparse.py index 677c6b546..470845200 100644 --- a/libpathod/rparse.py +++ b/libpathod/rparse.py @@ -390,6 +390,9 @@ class Response: return ret def write_values(self, fp, vals, actions, sofar=0, skip=0, blocksize=BLOCKSIZE): + """ + Return True if connection should disconnect. + """ while vals: part = vals.pop() for i in range(skip, len(part), blocksize): @@ -401,18 +404,15 @@ class Response: if p[1] == "pause": fp.write(d[:offset]) time.sleep(p[2]) - self.write_values( + return self.write_values( fp, vals, actions, sofar=sofar+offset, skip=i+offset, blocksize=blocksize ) - return elif p[1] == "disconnect": fp.write(d[:offset]) - fp.finish() - fp.connection.stream.close() - return + return True fp.write(d) sofar += len(d) skip = 0 @@ -447,9 +447,10 @@ class Response: vals.reverse() actions = self.ready_actions(self.length(), self.actions) actions.reverse() - self.write_values(fp, vals, actions[:]) + disconnect = self.write_values(fp, vals, actions[:]) duration = time.time() - started return dict( + disconnect = disconnect, started = started, duration = duration, actions = actions, diff --git a/libpathod/utils.py b/libpathod/utils.py index 0e3bda9df..f421b8a6a 100644 --- a/libpathod/utils.py +++ b/libpathod/utils.py @@ -4,31 +4,6 @@ import rparse class AnchorError(Exception): pass -class Sponge: - def __getattr__(self, x): - return Sponge() - - def __call__(self, *args, **kwargs): - pass - - -class DummyRequest: - connection = Sponge() - def __init__(self): - self.buf = [] - - def write(self, d, callback=None): - self.buf.append(str(d)) - if callback: - callback() - - def getvalue(self): - return "".join(self.buf) - - def finish(self): - return - - def parse_anchor_spec(s, settings): """ For now, this is very simple, and you can't have an '=' in your regular diff --git a/test/test_pathod.py b/test/test_pathod.py index 3fd2388ad..36a2d0906 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -1,7 +1,6 @@ from libpathod import pathod -from tornado import httpserver -class TestApplication: +class _TestApplication: def test_anchors(self): a = pathod.PathodApp(staticdir=None) a.add_anchor("/foo", "200") @@ -30,6 +29,7 @@ class TestApplication: assert not a.log_by_id(0) -def test_make_server(): - app = pathod.PathodApp() - assert pathod.make_server(app, 0, "127.0.0.1", None) +class TestPathod: + def test_instantiation(self): + pathod.Pathod(("127.0.0.1", 0)) + diff --git a/test/test_rparse.py b/test/test_rparse.py index f0db75fdb..0813f22e4 100644 --- a/test/test_rparse.py +++ b/test/test_rparse.py @@ -1,4 +1,4 @@ -import os +import os, cStringIO from libpathod import rparse, utils import tutils @@ -131,7 +131,7 @@ class TestMisc: assert r.msg.val == "Unknown code" def test_internal_response(self): - d = utils.DummyRequest() + d = cStringIO.StringIO() s = rparse.InternalResponse(400, "foo") s.serve(d) @@ -245,7 +245,7 @@ class TestResponse: def test_write_values_disconnects(self): r = self.dummy_response() - s = utils.DummyRequest() + s = cStringIO.StringIO() tst = "foo"*100 r.write_values(s, [tst], [(0, "disconnect")], blocksize=5) assert not s.getvalue() @@ -254,7 +254,7 @@ class TestResponse: tst = "foo"*1025 r = rparse.parse({}, "400'msg'") - s = utils.DummyRequest() + s = cStringIO.StringIO() r.write_values(s, [tst], []) assert s.getvalue() == tst @@ -263,29 +263,29 @@ class TestResponse: r = rparse.parse({}, "400'msg'") for i in range(2, 10): - s = utils.DummyRequest() + s = cStringIO.StringIO() r.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i) assert s.getvalue() == tst for i in range(2, 10): - s = utils.DummyRequest() + s = cStringIO.StringIO() r.write_values(s, [tst], [(1, "pause", 0)], blocksize=i) assert s.getvalue() == tst tst = ["".join(str(i) for i in range(10))]*5 for i in range(2, 10): - s = utils.DummyRequest() + s = cStringIO.StringIO() r.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i) assert s.getvalue() == "".join(tst) def test_render(self): - s = utils.DummyRequest() + s = cStringIO.StringIO() r = rparse.parse({}, "400'msg'") assert r.serve(s) def test_length(self): def testlen(x): - s = utils.DummyRequest() + s = cStringIO.StringIO() x.serve(s) assert x.length() == len(s.getvalue()) testlen(rparse.parse({}, "400'msg'"))