Consolidate the various HTTP header dictionary classes into one,

which includes better handling of headers with repeated values
(e.g. Set-Cookie)
This commit is contained in:
Ben Darnell 2010-07-08 17:06:03 -07:00
parent e98735bf4b
commit a7dc5bc4b4
4 changed files with 162 additions and 62 deletions

View File

@ -24,6 +24,7 @@ import errno
import escape
import functools
import httplib
import httputil
import ioloop
import logging
import pycurl
@ -60,7 +61,7 @@ class HTTPClient(object):
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
buffer = cStringIO.StringIO()
headers = {}
headers = httputil.HTTPHeaders()
try:
_curl_setup_request(self._curl, request, buffer, headers)
self._curl.perform()
@ -254,7 +255,7 @@ class AsyncHTTPClient(object):
curl = self._free_list.pop()
(request, callback) = self._requests.popleft()
curl.info = {
"headers": {},
"headers": httputil.HTTPHeaders(),
"buffer": cStringIO.StringIO(),
"request": request,
"callback": callback,
@ -462,7 +463,7 @@ class AsyncHTTPClient2(object):
curl = self._free_list.pop()
(request, callback) = self._requests.popleft()
curl.info = {
"headers": {},
"headers": httputil.HTTPHeaders(),
"buffer": cStringIO.StringIO(),
"request": request,
"callback": callback,
@ -505,7 +506,7 @@ class AsyncHTTPClient2(object):
class HTTPRequest(object):
def __init__(self, url, method="GET", headers={}, body=None,
def __init__(self, url, method="GET", headers=None, body=None,
auth_username=None, auth_password=None,
connect_timeout=20.0, request_timeout=20.0,
if_modified_since=None, follow_redirects=True,
@ -513,6 +514,8 @@ class HTTPRequest(object):
network_interface=None, streaming_callback=None,
header_callback=None, prepare_curl_callback=None,
allow_nonstandard_methods=False):
if headers is None:
headers = httputil.HTTPHeaders()
if if_modified_since:
timestamp = calendar.timegm(if_modified_since.utctimetuple())
headers["If-Modified-Since"] = email.utils.formatdate(
@ -618,8 +621,13 @@ def _curl_create(max_simultaneous_connections=None):
def _curl_setup_request(curl, request, buffer, headers):
curl.setopt(pycurl.URL, request.url)
curl.setopt(pycurl.HTTPHEADER,
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
# Request headers may be either a regular dict or HTTPHeaders object
if isinstance(request.headers, httputil.HTTPHeaders):
curl.setopt(pycurl.HTTPHEADER,
[_utf8("%s: %s" % i) for i in request.headers.get_all()])
else:
curl.setopt(pycurl.HTTPHEADER,
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
if request.header_callback:
curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
else:
@ -695,17 +703,7 @@ def _curl_header_callback(headers, header_line):
return
if header_line == "\r\n":
return
parts = header_line.split(":", 1)
if len(parts) != 2:
logging.warning("Invalid HTTP response header line %r", header_line)
return
name = parts[0].strip()
value = parts[1].strip()
if name in headers:
headers[name] = headers[name] + ',' + value
else:
headers[name] = value
headers.parse_line(header_line)
def _curl_debug(debug_type, debug_msg):
debug_types = ('I', '<', '>', '<', '>')

View File

@ -19,6 +19,7 @@
import cgi
import errno
import functools
import httputil
import ioloop
import iostream
import logging
@ -277,7 +278,7 @@ class HTTPConnection(object):
method, uri, version = start_line.split(" ")
if not version.startswith("HTTP/"):
raise Exception("Malformed HTTP version in HTTP Request-Line")
headers = HTTPHeaders.parse(data[eol:])
headers = httputil.HTTPHeaders.parse(data[eol:])
self._request = HTTPRequest(
connection=self, method=method, uri=uri, version=version,
headers=headers, remote_ip=self.address[0])
@ -332,7 +333,7 @@ class HTTPConnection(object):
if eoh == -1:
logging.warning("multipart/form-data missing headers")
continue
headers = HTTPHeaders.parse(part[:eoh])
headers = httputil.HTTPHeaders.parse(part[:eoh])
name_header = headers.get("Content-Disposition", "")
if not name_header.startswith("form-data;") or \
not part.endswith("\r\n"):
@ -380,7 +381,7 @@ class HTTPRequest(object):
self.method = method
self.uri = uri
self.version = version
self.headers = headers or HTTPHeaders()
self.headers = headers or httputil.HTTPHeaders()
self.body = body or ""
if connection and connection.xheaders:
# Squid uses X-Forwarded-For, others use X-Real-Ip
@ -437,23 +438,3 @@ class HTTPRequest(object):
return "%s(%s, headers=%s)" % (
self.__class__.__name__, args, dict(self.headers))
class HTTPHeaders(dict):
"""A dictionary that maintains Http-Header-Case for all keys."""
def __setitem__(self, name, value):
dict.__setitem__(self, self._normalize_name(name), value)
def __getitem__(self, name):
return dict.__getitem__(self, self._normalize_name(name))
def _normalize_name(self, name):
return "-".join([w.capitalize() for w in name.split("-")])
@classmethod
def parse(cls, headers_string):
headers = cls()
for line in headers_string.splitlines():
if line:
name, value = line.split(":", 1)
headers[name] = value.strip()
return headers

140
tornado/httputil.py Executable file
View File

@ -0,0 +1,140 @@
#!/usr/bin/env python
#
# Copyright 2009 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""HTTP utility code shared by clients and servers."""
class HTTPHeaders(dict):
    """A dictionary that maintains Http-Header-Case for all keys.

    Supports multiple values per key via a pair of new methods,
    add() and get_list().  The regular dictionary interface returns a single
    value per key, with multiple values joined by a comma.

    Item access, ``get()``, ``del`` and the ``in`` operator all accept the
    header name in any letter case.

    >>> h = HTTPHeaders({"content-type": "text/html"})
    >>> list(h.keys())
    ['Content-Type']
    >>> h["Content-Type"]
    'text/html'

    >>> h.add("Set-Cookie", "A=B")
    >>> h.add("Set-Cookie", "C=D")
    >>> h["set-cookie"]
    'A=B,C=D'
    >>> h.get_list("set-cookie")
    ['A=B', 'C=D']
    >>> "set-cookie" in h
    True

    >>> for (k,v) in sorted(h.get_all()):
    ...    print('%s: %s' % (k,v))
    ...
    Content-Type: text/html
    Set-Cookie: A=B
    Set-Cookie: C=D

    Copying keeps repeated values separate:

    >>> HTTPHeaders(h).get_list("Set-Cookie")
    ['A=B', 'C=D']
    """

    def __init__(self, *args, **kwargs):
        # Don't pass args or kwargs to dict.__init__, as it will bypass
        # our __setitem__
        dict.__init__(self)
        # Maps each normalized header name to the list of its individual
        # values; the dict itself stores the comma-joined form.
        self._as_list = {}
        self.update(*args, **kwargs)

    # new public methods

    def add(self, name, value):
        """Adds a new value for the given key.

        If the header is already present the value is appended to it,
        otherwise the header is created with this single value.
        """
        norm_name = HTTPHeaders._normalize_name(name)
        if norm_name in self._as_list:
            # bypass our override of __setitem__ since it modifies _as_list
            joined = dict.__getitem__(self, norm_name) + ',' + value
            dict.__setitem__(self, norm_name, joined)
            self._as_list[norm_name].append(value)
        else:
            self[norm_name] = value

    def get_list(self, name):
        """Returns all values for the given header as a list.

        Returns an empty list when the header is not present.
        """
        norm_name = HTTPHeaders._normalize_name(name)
        return self._as_list.get(norm_name, [])

    def get_all(self):
        """Returns an iterable of all (name, value) pairs.

        If a header has multiple values, multiple pairs will be
        returned with the same name.
        """
        # items() (rather than iteritems()) behaves the same on Python 2
        # and also works on Python 3.
        for name, values in self._as_list.items():
            for value in values:
                yield (name, value)

    def parse_line(self, line):
        """Updates the dictionary with a single header line.

        The text before the first colon is the header name; the rest, with
        surrounding whitespace stripped, is added as a value for it.

        Raises ValueError if the line contains no colon.

        >>> h = HTTPHeaders()
        >>> h.parse_line("Content-Type: text/html")
        >>> h.get('content-type')
        'text/html'
        """
        name, colon, value = line.partition(":")
        if not colon:
            # Same exception type the bare tuple-unpacking used to raise,
            # but with a message that identifies the offending input.
            raise ValueError("Malformed HTTP header line: %r" % (line,))
        self.add(name, value.strip())

    @classmethod
    def parse(cls, headers):
        """Returns a dictionary from HTTP header text.

        Empty lines are skipped.  Raises ValueError if a non-empty line
        contains no colon.

        >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
        >>> sorted(h.items())
        [('Content-Length', '42'), ('Content-Type', 'text/html')]
        """
        h = cls()
        for line in headers.splitlines():
            if line:
                h.parse_line(line)
        return h

    # dict implementation overrides

    def __setitem__(self, name, value):
        # Assignment replaces any previously stored values for the header.
        norm_name = HTTPHeaders._normalize_name(name)
        dict.__setitem__(self, norm_name, value)
        self._as_list[norm_name] = [value]

    def __getitem__(self, name):
        return dict.__getitem__(self, HTTPHeaders._normalize_name(name))

    def __delitem__(self, name):
        norm_name = HTTPHeaders._normalize_name(name)
        dict.__delitem__(self, norm_name)
        del self._as_list[norm_name]

    def __contains__(self, name):
        # Normalize so that membership tests agree with [], get() and del.
        return dict.__contains__(self, HTTPHeaders._normalize_name(name))

    def get(self, name, default=None):
        return dict.get(self, HTTPHeaders._normalize_name(name), default)

    def update(self, *args, **kwargs):
        """Sets headers from a mapping, an iterable of pairs and/or keywords.

        Existing values for the affected headers are replaced.  When the
        single positional argument is another HTTPHeaders, each of its
        values is carried over individually instead of as one comma-joined
        string.
        """
        if len(args) == 1 and isinstance(args[0], HTTPHeaders):
            # Snapshot the items so that h.update(h) is safe.
            for name, values in list(args[0]._as_list.items()):
                self[name] = values[0]
                for extra in values[1:]:
                    self.add(name, extra)
            args = ()
        # dict.update bypasses our __setitem__
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    @staticmethod
    def _normalize_name(name):
        """Converts a name to Http-Header-Case.

        >>> HTTPHeaders._normalize_name("coNtent-TYPE")
        'Content-Type'
        """
        return "-".join([w.capitalize() for w in name.split("-")])
if __name__ == "__main__":
    # Executed directly as a script: run the doctests embedded in this
    # module's docstrings.
    from doctest import testmod
    testmod()

View File

@ -54,6 +54,7 @@ import cgi
import cStringIO
import escape
import httplib
import httputil
import logging
import sys
import time
@ -100,7 +101,7 @@ class HTTPRequest(object):
values = [v for v in values if v]
if values: self.arguments[name] = values
self.version = "HTTP/1.1"
self.headers = HTTPHeaders()
self.headers = httputil.HTTPHeaders()
if environ.get("CONTENT_TYPE"):
self.headers["Content-Type"] = environ["CONTENT_TYPE"]
if environ.get("CONTENT_LENGTH"):
@ -164,7 +165,7 @@ class HTTPRequest(object):
if eoh == -1:
logging.warning("multipart/form-data missing headers")
continue
headers = HTTPHeaders.parse(part[:eoh])
headers = httputil.HTTPHeaders.parse(part[:eoh])
name_header = headers.get("Content-Disposition", "")
if not name_header.startswith("form-data;") or \
not part.endswith("\r\n"):
@ -293,23 +294,3 @@ class WSGIContainer(object):
request.remote_ip + ")"
log_method("%d %s %.2fms", status_code, summary, request_time)
class HTTPHeaders(dict):
"""A dictionary that maintains Http-Header-Case for all keys."""
def __setitem__(self, name, value):
dict.__setitem__(self, self._normalize_name(name), value)
def __getitem__(self, name):
return dict.__getitem__(self, self._normalize_name(name))
def _normalize_name(self, name):
return "-".join([w.capitalize() for w in name.split("-")])
@classmethod
def parse(cls, headers_string):
headers = cls()
for line in headers_string.splitlines():
if line:
name, value = line.split(":", 1)
headers[name] = value.strip()
return headers