Add a Host header to pathoc requests by default.

This commit is contained in:
Aldo Cortesi 2012-07-24 21:38:28 +12:00
parent 9502eeadaa
commit 94b491bb27
3 changed files with 30 additions and 21 deletions

View File

@ -20,7 +20,7 @@ class Pathoc(tcp.TCPClient):
tcp.TCPClient.__init__(self, host, port) tcp.TCPClient.__init__(self, host, port)
self.settings = dict( self.settings = dict(
staticdir = os.getcwd(), staticdir = os.getcwd(),
unconstrained_file_access = True unconstrained_file_access = True,
) )
def request(self, spec): def request(self, spec):
@ -31,7 +31,7 @@ class Pathoc(tcp.TCPClient):
rparse.FileAccessDenied. rparse.FileAccessDenied.
""" """
r = rparse.parse_request(self.settings, spec) r = rparse.parse_request(self.settings, spec)
ret = r.serve(self.wfile) ret = r.serve(self.wfile, None, self.host)
self.wfile.flush() self.wfile.flush()
return http.read_response(self.rfile, r.method, None) return http.read_response(self.rfile, r.method, None)
@ -43,7 +43,7 @@ class Pathoc(tcp.TCPClient):
for i in reqs: for i in reqs:
try: try:
r = rparse.parse_request(self.settings, i) r = rparse.parse_request(self.settings, i)
req = r.serve(self.wfile) req = r.serve(self.wfile, None, self.host)
if reqdump: if reqdump:
print >> fp, "\n>>", req["method"], repr(req["path"]) print >> fp, "\n>>", req["method"], repr(req["path"])
for a in req["actions"]: for a in req["actions"]:

View File

@ -541,7 +541,7 @@ class Message:
l += len(i[2]) l += len(i[2])
return l return l
def serve(self, fp, check, is_request): def serve(self, fp, check, request_host):
""" """
fp: The file pointer to write to. fp: The file pointer to write to.
@ -550,9 +550,11 @@ class Message:
otherwise the return is treated as an error message to be sent to otherwise the return is treated as an error message to be sent to
the client, and service stops. the client, and service stops.
is_request: Is this a request? If False, we assume it's a response. request_host: If this a request, this is the connecting host. If
Used to decide what standard modifications to make if raw is not None, we assume it's a response. Used to decide what standard
set. modifications to make if raw is not set.
Calling this function may modify the object.
""" """
started = time.time() started = time.time()
if not self.raw: if not self.raw:
@ -563,8 +565,15 @@ class Message:
LiteralGenerator(str(len(self.body))), LiteralGenerator(str(len(self.body))),
) )
) )
if is_request: if request_host:
pass if not utils.get_header("Host", self.headers):
self.headers.append(
(
LiteralGenerator("Host"),
LiteralGenerator(request_host)
)
)
else: else:
if not utils.get_header("Date", self.headers): if not utils.get_header("Date", self.headers):
self.headers.append( self.headers.append(
@ -706,8 +715,8 @@ class CraftedRequest(Request):
for i in tokens: for i in tokens:
i.accept(settings, self) i.accept(settings, self)
def serve(self, fp, check=None): def serve(self, fp, check, host):
d = Request.serve(self, fp, check, True) d = Request.serve(self, fp, check, host)
d["spec"] = self.spec d["spec"] = self.spec
return d return d
@ -719,8 +728,8 @@ class CraftedResponse(Response):
for i in tokens: for i in tokens:
i.accept(settings, self) i.accept(settings, self)
def serve(self, fp, check=None): def serve(self, fp, check):
d = Response.serve(self, fp, check, False) d = Response.serve(self, fp, check, None)
d["spec"] = self.spec d["spec"] = self.spec
return d return d
@ -738,7 +747,7 @@ class PathodErrorResponse(Response):
] ]
def serve(self, fp, check=None): def serve(self, fp, check=None):
d = Response.serve(self, fp, check, False) d = Response.serve(self, fp, check, None)
d["internal"] = True d["internal"] = True
return d return d

View File

@ -210,7 +210,7 @@ class TestInject:
def test_serve(self): def test_serve(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = rparse.parse_response({}, "400:i0,'foo'") r = rparse.parse_response({}, "400:i0,'foo'")
assert r.serve(s) assert r.serve(s, None)
class TestShortcuts: class TestShortcuts:
@ -262,7 +262,7 @@ class TestParseRequest:
def test_render(self): def test_render(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = rparse.parse_request({}, "GET:'/foo'") r = rparse.parse_request({}, "GET:'/foo'")
assert r.serve(s) assert r.serve(s, None, "foo.com")
def test_str(self): def test_str(self):
r = rparse.parse_request({}, 'GET:"/foo"') r = rparse.parse_request({}, 'GET:"/foo"')
@ -438,19 +438,19 @@ class TestResponse:
def test_render(self): def test_render(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = rparse.parse_response({}, "400'msg'") r = rparse.parse_response({}, "400'msg'")
assert r.serve(s) assert r.serve(s, None)
def test_raw(self): def test_raw(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = rparse.parse_response({}, "400:b'foo'") r = rparse.parse_response({}, "400:b'foo'")
r.serve(s) r.serve(s, None)
v = s.getvalue() v = s.getvalue()
assert "Content-Length" in v assert "Content-Length" in v
assert "Date" in v assert "Date" in v
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = rparse.parse_response({}, "400:b'foo':r") r = rparse.parse_response({}, "400:b'foo':r")
r.serve(s) r.serve(s, None)
v = s.getvalue() v = s.getvalue()
assert not "Content-Length" in v assert not "Content-Length" in v
assert not "Date" in v assert not "Date" in v
@ -458,7 +458,7 @@ class TestResponse:
def test_length(self): def test_length(self):
def testlen(x): def testlen(x):
s = cStringIO.StringIO() s = cStringIO.StringIO()
x.serve(s) x.serve(s, None)
assert x.length() == len(s.getvalue()) assert x.length() == len(s.getvalue())
testlen(rparse.parse_response({}, "400'msg'")) testlen(rparse.parse_response({}, "400'msg'"))
testlen(rparse.parse_response({}, "400'msg':h'foo'='bar'")) testlen(rparse.parse_response({}, "400'msg':h'foo'='bar'"))
@ -467,7 +467,7 @@ class TestResponse:
def test_effective_length(self): def test_effective_length(self):
def testlen(x, actions): def testlen(x, actions):
s = cStringIO.StringIO() s = cStringIO.StringIO()
x.serve(s) x.serve(s, None)
assert x.effective_length(actions) == len(s.getvalue()) assert x.effective_length(actions) == len(s.getvalue())
actions = [ actions = [