Consolidate the various HTTP header dictionary classes into one,
which includes better handling of headers with repeated values (e.g. Set-Cookie)
This commit is contained in:
parent
e98735bf4b
commit
a7dc5bc4b4
|
@ -24,6 +24,7 @@ import errno
|
|||
import escape
|
||||
import functools
|
||||
import httplib
|
||||
import httputil
|
||||
import ioloop
|
||||
import logging
|
||||
import pycurl
|
||||
|
@ -60,7 +61,7 @@ class HTTPClient(object):
|
|||
if not isinstance(request, HTTPRequest):
|
||||
request = HTTPRequest(url=request, **kwargs)
|
||||
buffer = cStringIO.StringIO()
|
||||
headers = {}
|
||||
headers = httputil.HTTPHeaders()
|
||||
try:
|
||||
_curl_setup_request(self._curl, request, buffer, headers)
|
||||
self._curl.perform()
|
||||
|
@ -254,7 +255,7 @@ class AsyncHTTPClient(object):
|
|||
curl = self._free_list.pop()
|
||||
(request, callback) = self._requests.popleft()
|
||||
curl.info = {
|
||||
"headers": {},
|
||||
"headers": httputil.HTTPHeaders(),
|
||||
"buffer": cStringIO.StringIO(),
|
||||
"request": request,
|
||||
"callback": callback,
|
||||
|
@ -462,7 +463,7 @@ class AsyncHTTPClient2(object):
|
|||
curl = self._free_list.pop()
|
||||
(request, callback) = self._requests.popleft()
|
||||
curl.info = {
|
||||
"headers": {},
|
||||
"headers": httputil.HTTPHeaders(),
|
||||
"buffer": cStringIO.StringIO(),
|
||||
"request": request,
|
||||
"callback": callback,
|
||||
|
@ -505,7 +506,7 @@ class AsyncHTTPClient2(object):
|
|||
|
||||
|
||||
class HTTPRequest(object):
|
||||
def __init__(self, url, method="GET", headers={}, body=None,
|
||||
def __init__(self, url, method="GET", headers=None, body=None,
|
||||
auth_username=None, auth_password=None,
|
||||
connect_timeout=20.0, request_timeout=20.0,
|
||||
if_modified_since=None, follow_redirects=True,
|
||||
|
@ -513,6 +514,8 @@ class HTTPRequest(object):
|
|||
network_interface=None, streaming_callback=None,
|
||||
header_callback=None, prepare_curl_callback=None,
|
||||
allow_nonstandard_methods=False):
|
||||
if headers is None:
|
||||
headers = httputil.HTTPHeaders()
|
||||
if if_modified_since:
|
||||
timestamp = calendar.timegm(if_modified_since.utctimetuple())
|
||||
headers["If-Modified-Since"] = email.utils.formatdate(
|
||||
|
@ -618,8 +621,13 @@ def _curl_create(max_simultaneous_connections=None):
|
|||
|
||||
def _curl_setup_request(curl, request, buffer, headers):
|
||||
curl.setopt(pycurl.URL, request.url)
|
||||
curl.setopt(pycurl.HTTPHEADER,
|
||||
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
|
||||
# Request headers may be either a regular dict or HTTPHeaders object
|
||||
if isinstance(request.headers, httputil.HTTPHeaders):
|
||||
curl.setopt(pycurl.HTTPHEADER,
|
||||
[_utf8("%s: %s" % i) for i in request.headers.get_all()])
|
||||
else:
|
||||
curl.setopt(pycurl.HTTPHEADER,
|
||||
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
|
||||
if request.header_callback:
|
||||
curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
|
||||
else:
|
||||
|
@ -695,17 +703,7 @@ def _curl_header_callback(headers, header_line):
|
|||
return
|
||||
if header_line == "\r\n":
|
||||
return
|
||||
parts = header_line.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
logging.warning("Invalid HTTP response header line %r", header_line)
|
||||
return
|
||||
name = parts[0].strip()
|
||||
value = parts[1].strip()
|
||||
if name in headers:
|
||||
headers[name] = headers[name] + ',' + value
|
||||
else:
|
||||
headers[name] = value
|
||||
|
||||
headers.parse_line(header_line)
|
||||
|
||||
def _curl_debug(debug_type, debug_msg):
|
||||
debug_types = ('I', '<', '>', '<', '>')
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
import cgi
|
||||
import errno
|
||||
import functools
|
||||
import httputil
|
||||
import ioloop
|
||||
import iostream
|
||||
import logging
|
||||
|
@ -277,7 +278,7 @@ class HTTPConnection(object):
|
|||
method, uri, version = start_line.split(" ")
|
||||
if not version.startswith("HTTP/"):
|
||||
raise Exception("Malformed HTTP version in HTTP Request-Line")
|
||||
headers = HTTPHeaders.parse(data[eol:])
|
||||
headers = httputil.HTTPHeaders.parse(data[eol:])
|
||||
self._request = HTTPRequest(
|
||||
connection=self, method=method, uri=uri, version=version,
|
||||
headers=headers, remote_ip=self.address[0])
|
||||
|
@ -332,7 +333,7 @@ class HTTPConnection(object):
|
|||
if eoh == -1:
|
||||
logging.warning("multipart/form-data missing headers")
|
||||
continue
|
||||
headers = HTTPHeaders.parse(part[:eoh])
|
||||
headers = httputil.HTTPHeaders.parse(part[:eoh])
|
||||
name_header = headers.get("Content-Disposition", "")
|
||||
if not name_header.startswith("form-data;") or \
|
||||
not part.endswith("\r\n"):
|
||||
|
@ -380,7 +381,7 @@ class HTTPRequest(object):
|
|||
self.method = method
|
||||
self.uri = uri
|
||||
self.version = version
|
||||
self.headers = headers or HTTPHeaders()
|
||||
self.headers = headers or httputil.HTTPHeaders()
|
||||
self.body = body or ""
|
||||
if connection and connection.xheaders:
|
||||
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
||||
|
@ -437,23 +438,3 @@ class HTTPRequest(object):
|
|||
return "%s(%s, headers=%s)" % (
|
||||
self.__class__.__name__, args, dict(self.headers))
|
||||
|
||||
|
||||
class HTTPHeaders(dict):
    """A dict subclass that canonicalizes every key to Http-Header-Case."""

    def __setitem__(self, name, value):
        # Store under the canonical capitalization so that lookups are
        # case-insensitive with respect to the '-'-separated words.
        canonical = self._normalize_name(name)
        dict.__setitem__(self, canonical, value)

    def __getitem__(self, name):
        canonical = self._normalize_name(name)
        return dict.__getitem__(self, canonical)

    def _normalize_name(self, name):
        # "content-TYPE" -> "Content-Type"
        words = name.split("-")
        return "-".join(word.capitalize() for word in words)

    @classmethod
    def parse(cls, headers_string):
        """Builds an HTTPHeaders instance from raw header text."""
        parsed = cls()
        for raw_line in headers_string.splitlines():
            if not raw_line:
                continue
            key, value = raw_line.split(":", 1)
            parsed[key] = value.strip()
        return parsed
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
#!/usr/bin/env python
|
||||
#
|
||||
# Copyright 2009 Facebook
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""HTTP utility code shared by clients and servers."""
|
||||
|
||||
class HTTPHeaders(dict):
    """A dictionary that maintains Http-Header-Case for all keys.

    Supports multiple values per key via a pair of new methods,
    add() and get_list(). The regular dictionary interface returns a single
    value per key, with multiple values joined by a comma.

    >>> h = HTTPHeaders({"content-type": "text/html"})
    >>> list(h.keys())
    ['Content-Type']
    >>> h["Content-Type"]
    'text/html'

    >>> h.add("Set-Cookie", "A=B")
    >>> h.add("Set-Cookie", "C=D")
    >>> h["set-cookie"]
    'A=B,C=D'
    >>> h.get_list("set-cookie")
    ['A=B', 'C=D']

    >>> for (k, v) in sorted(h.get_all()):
    ...    print('%s: %s' % (k, v))
    ...
    Content-Type: text/html
    Set-Cookie: A=B
    Set-Cookie: C=D
    """
    def __init__(self, *args, **kwargs):
        # Don't pass args or kwargs to dict.__init__, as it would bypass
        # our __setitem__ and leave _as_list out of sync with the dict.
        dict.__init__(self)
        self._as_list = {}  # normalized name -> list of individual values
        self.update(*args, **kwargs)

    # new public methods

    def add(self, name, value):
        """Adds a new value for the given key."""
        norm_name = HTTPHeaders._normalize_name(name)
        if norm_name in self:
            # Bypass our override of __setitem__, which would replace
            # (rather than extend) the _as_list entry for this header.
            dict.__setitem__(self, norm_name, self[norm_name] + ',' + value)
            self._as_list[norm_name].append(value)
        else:
            self[norm_name] = value

    def get_list(self, name):
        """Returns all values for the given header as a list."""
        norm_name = HTTPHeaders._normalize_name(name)
        return self._as_list.get(norm_name, [])

    def get_all(self):
        """Returns an iterable of all (name, value) pairs.

        If a header has multiple values, multiple pairs will be
        returned with the same name.
        """
        # "values" rather than "list" to avoid shadowing the builtin.
        for name, values in self._as_list.items():
            for value in values:
                yield (name, value)

    def parse_line(self, line):
        """Updates the dictionary with a single header line.

        Raises ValueError if the line does not contain a colon.

        >>> h = HTTPHeaders()
        >>> h.parse_line("Content-Type: text/html")
        >>> h.get('content-type')
        'text/html'
        """
        name, value = line.split(":", 1)
        self.add(name, value.strip())

    @classmethod
    def parse(cls, headers):
        """Returns a dictionary from HTTP header text.

        >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
        >>> sorted(h.items())
        [('Content-Length', '42'), ('Content-Type', 'text/html')]
        """
        h = cls()
        for line in headers.splitlines():
            if line:
                h.parse_line(line)
        return h

    # dict implementation overrides

    def __setitem__(self, name, value):
        norm_name = HTTPHeaders._normalize_name(name)
        dict.__setitem__(self, norm_name, value)
        self._as_list[norm_name] = [value]

    def __getitem__(self, name):
        return dict.__getitem__(self, HTTPHeaders._normalize_name(name))

    def __delitem__(self, name):
        norm_name = HTTPHeaders._normalize_name(name)
        dict.__delitem__(self, norm_name)
        del self._as_list[norm_name]

    def get(self, name, default=None):
        return dict.get(self, HTTPHeaders._normalize_name(name), default)

    def update(self, *args, **kwargs):
        # dict.update bypasses our __setitem__; assign key-by-key so that
        # _as_list stays consistent with the dict contents.
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    @staticmethod
    def _normalize_name(name):
        """Converts a name to Http-Header-Case.

        >>> HTTPHeaders._normalize_name("coNtent-TYPE")
        'Content-Type'
        """
        return "-".join([w.capitalize() for w in name.split("-")])
|
||||
|
||||
|
||||
# Run this module's embedded doctests when executed directly
# (e.g. `python httputil.py`); silent on success.
if __name__ == "__main__":
    import doctest
    doctest.testmod()
|
|
@ -54,6 +54,7 @@ import cgi
|
|||
import cStringIO
|
||||
import escape
|
||||
import httplib
|
||||
import httputil
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
@ -100,7 +101,7 @@ class HTTPRequest(object):
|
|||
values = [v for v in values if v]
|
||||
if values: self.arguments[name] = values
|
||||
self.version = "HTTP/1.1"
|
||||
self.headers = HTTPHeaders()
|
||||
self.headers = httputil.HTTPHeaders()
|
||||
if environ.get("CONTENT_TYPE"):
|
||||
self.headers["Content-Type"] = environ["CONTENT_TYPE"]
|
||||
if environ.get("CONTENT_LENGTH"):
|
||||
|
@ -164,7 +165,7 @@ class HTTPRequest(object):
|
|||
if eoh == -1:
|
||||
logging.warning("multipart/form-data missing headers")
|
||||
continue
|
||||
headers = HTTPHeaders.parse(part[:eoh])
|
||||
headers = httputil.HTTPHeaders.parse(part[:eoh])
|
||||
name_header = headers.get("Content-Disposition", "")
|
||||
if not name_header.startswith("form-data;") or \
|
||||
not part.endswith("\r\n"):
|
||||
|
@ -293,23 +294,3 @@ class WSGIContainer(object):
|
|||
request.remote_ip + ")"
|
||||
log_method("%d %s %.2fms", status_code, summary, request_time)
|
||||
|
||||
|
||||
class HTTPHeaders(dict):
    """A dictionary that maintains Http-Header-Case for all keys."""

    def __setitem__(self, name, value):
        super(HTTPHeaders, self).__setitem__(self._normalize_name(name), value)

    def __getitem__(self, name):
        return super(HTTPHeaders, self).__getitem__(self._normalize_name(name))

    def _normalize_name(self, name):
        # Capitalize each '-'-separated word: "coNtent-tYpe" -> "Content-Type"
        capitalized = [part.capitalize() for part in name.split("-")]
        return "-".join(capitalized)

    @classmethod
    def parse(cls, headers_string):
        """Parses raw HTTP header text into an HTTPHeaders instance."""
        result = cls()
        nonblank_lines = (line for line in headers_string.splitlines() if line)
        for line in nonblank_lines:
            name, value = line.split(":", 1)
            result[name] = value.strip()
        return result
|
||||
|
|
Loading…
Reference in New Issue