Disconnect, rest refactoring.

This commit is contained in:
Aldo Cortesi 2012-06-21 10:56:30 +12:00
parent de00497b40
commit 1089a52f3d
5 changed files with 39 additions and 63 deletions

View File

@ -18,6 +18,10 @@ class PathodHandler(tcp.BaseHandler):
return None
method, path, httpversion = protocol.parse_init_http(line)
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, None
)
if path.startswith(self.server.prefix):
spec = urllib.unquote(path)[len(self.server.prefix):]
try:
@ -27,24 +31,20 @@ class PathodHandler(tcp.BaseHandler):
800,
"Error parsing response spec: %s\n"%v.msg + v.marked()
)
presp.serve(self.wfile)
self.finish()
return
headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, None
)
cc = wsgi.ClientConn(self.client_address)
req = wsgi.Request(cc, "http", method, path, headers, content)
sn = self.connection.getsockname()
app = wsgi.WSGIAdaptor(
self.server.app,
sn[0],
self.server.port,
version.NAMEVERSION
)
app.serve(req, self.wfile)
ret = presp.serve(self.wfile)
if ret["disconnect"]:
self.close()
else:
cc = wsgi.ClientConn(self.client_address)
req = wsgi.Request(cc, "http", method, path, headers, content)
sn = self.connection.getsockname()
app = wsgi.WSGIAdaptor(
self.server.app,
sn[0],
self.server.port,
version.NAMEVERSION
)
app.serve(req, self.wfile)
class Pathod(tcp.TCPServer):

View File

@ -390,6 +390,9 @@ class Response:
return ret
def write_values(self, fp, vals, actions, sofar=0, skip=0, blocksize=BLOCKSIZE):
"""
Return True if connection should disconnect.
"""
while vals:
part = vals.pop()
for i in range(skip, len(part), blocksize):
@ -401,18 +404,15 @@ class Response:
if p[1] == "pause":
fp.write(d[:offset])
time.sleep(p[2])
self.write_values(
return self.write_values(
fp, vals, actions,
sofar=sofar+offset,
skip=i+offset,
blocksize=blocksize
)
return
elif p[1] == "disconnect":
fp.write(d[:offset])
fp.finish()
fp.connection.stream.close()
return
return True
fp.write(d)
sofar += len(d)
skip = 0
@ -447,9 +447,10 @@ class Response:
vals.reverse()
actions = self.ready_actions(self.length(), self.actions)
actions.reverse()
self.write_values(fp, vals, actions[:])
disconnect = self.write_values(fp, vals, actions[:])
duration = time.time() - started
return dict(
disconnect = disconnect,
started = started,
duration = duration,
actions = actions,

View File

@ -4,31 +4,6 @@ import rparse
class AnchorError(Exception): pass
class Sponge:
def __getattr__(self, x):
return Sponge()
def __call__(self, *args, **kwargs):
pass
class DummyRequest:
connection = Sponge()
def __init__(self):
self.buf = []
def write(self, d, callback=None):
self.buf.append(str(d))
if callback:
callback()
def getvalue(self):
return "".join(self.buf)
def finish(self):
return
def parse_anchor_spec(s, settings):
"""
For now, this is very simple, and you can't have an '=' in your regular

View File

@ -1,7 +1,6 @@
from libpathod import pathod
from tornado import httpserver
class TestApplication:
class _TestApplication:
def test_anchors(self):
a = pathod.PathodApp(staticdir=None)
a.add_anchor("/foo", "200")
@ -30,6 +29,7 @@ class TestApplication:
assert not a.log_by_id(0)
def test_make_server():
app = pathod.PathodApp()
assert pathod.make_server(app, 0, "127.0.0.1", None)
class TestPathod:
def test_instantiation(self):
pathod.Pathod(("127.0.0.1", 0))

View File

@ -1,4 +1,4 @@
import os
import os, cStringIO
from libpathod import rparse, utils
import tutils
@ -131,7 +131,7 @@ class TestMisc:
assert r.msg.val == "Unknown code"
def test_internal_response(self):
d = utils.DummyRequest()
d = cStringIO.StringIO()
s = rparse.InternalResponse(400, "foo")
s.serve(d)
@ -245,7 +245,7 @@ class TestResponse:
def test_write_values_disconnects(self):
r = self.dummy_response()
s = utils.DummyRequest()
s = cStringIO.StringIO()
tst = "foo"*100
r.write_values(s, [tst], [(0, "disconnect")], blocksize=5)
assert not s.getvalue()
@ -254,7 +254,7 @@ class TestResponse:
tst = "foo"*1025
r = rparse.parse({}, "400'msg'")
s = utils.DummyRequest()
s = cStringIO.StringIO()
r.write_values(s, [tst], [])
assert s.getvalue() == tst
@ -263,29 +263,29 @@ class TestResponse:
r = rparse.parse({}, "400'msg'")
for i in range(2, 10):
s = utils.DummyRequest()
s = cStringIO.StringIO()
r.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst
for i in range(2, 10):
s = utils.DummyRequest()
s = cStringIO.StringIO()
r.write_values(s, [tst], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst
tst = ["".join(str(i) for i in range(10))]*5
for i in range(2, 10):
s = utils.DummyRequest()
s = cStringIO.StringIO()
r.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == "".join(tst)
def test_render(self):
s = utils.DummyRequest()
s = cStringIO.StringIO()
r = rparse.parse({}, "400'msg'")
assert r.serve(s)
def test_length(self):
def testlen(x):
s = utils.DummyRequest()
s = cStringIO.StringIO()
x.serve(s)
assert x.length() == len(s.getvalue())
testlen(rparse.parse({}, "400'msg'"))