Consolidate the various HTTP header dictionary classes into one,
which includes better handling of headers with repeated values (e.g. Set-Cookie)
This commit is contained in:
parent
e98735bf4b
commit
a7dc5bc4b4
|
@ -24,6 +24,7 @@ import errno
|
||||||
import escape
|
import escape
|
||||||
import functools
|
import functools
|
||||||
import httplib
|
import httplib
|
||||||
|
import httputil
|
||||||
import ioloop
|
import ioloop
|
||||||
import logging
|
import logging
|
||||||
import pycurl
|
import pycurl
|
||||||
|
@ -60,7 +61,7 @@ class HTTPClient(object):
|
||||||
if not isinstance(request, HTTPRequest):
|
if not isinstance(request, HTTPRequest):
|
||||||
request = HTTPRequest(url=request, **kwargs)
|
request = HTTPRequest(url=request, **kwargs)
|
||||||
buffer = cStringIO.StringIO()
|
buffer = cStringIO.StringIO()
|
||||||
headers = {}
|
headers = httputil.HTTPHeaders()
|
||||||
try:
|
try:
|
||||||
_curl_setup_request(self._curl, request, buffer, headers)
|
_curl_setup_request(self._curl, request, buffer, headers)
|
||||||
self._curl.perform()
|
self._curl.perform()
|
||||||
|
@ -254,7 +255,7 @@ class AsyncHTTPClient(object):
|
||||||
curl = self._free_list.pop()
|
curl = self._free_list.pop()
|
||||||
(request, callback) = self._requests.popleft()
|
(request, callback) = self._requests.popleft()
|
||||||
curl.info = {
|
curl.info = {
|
||||||
"headers": {},
|
"headers": httputil.HTTPHeaders(),
|
||||||
"buffer": cStringIO.StringIO(),
|
"buffer": cStringIO.StringIO(),
|
||||||
"request": request,
|
"request": request,
|
||||||
"callback": callback,
|
"callback": callback,
|
||||||
|
@ -462,7 +463,7 @@ class AsyncHTTPClient2(object):
|
||||||
curl = self._free_list.pop()
|
curl = self._free_list.pop()
|
||||||
(request, callback) = self._requests.popleft()
|
(request, callback) = self._requests.popleft()
|
||||||
curl.info = {
|
curl.info = {
|
||||||
"headers": {},
|
"headers": httputil.HTTPHeaders(),
|
||||||
"buffer": cStringIO.StringIO(),
|
"buffer": cStringIO.StringIO(),
|
||||||
"request": request,
|
"request": request,
|
||||||
"callback": callback,
|
"callback": callback,
|
||||||
|
@ -505,7 +506,7 @@ class AsyncHTTPClient2(object):
|
||||||
|
|
||||||
|
|
||||||
class HTTPRequest(object):
|
class HTTPRequest(object):
|
||||||
def __init__(self, url, method="GET", headers={}, body=None,
|
def __init__(self, url, method="GET", headers=None, body=None,
|
||||||
auth_username=None, auth_password=None,
|
auth_username=None, auth_password=None,
|
||||||
connect_timeout=20.0, request_timeout=20.0,
|
connect_timeout=20.0, request_timeout=20.0,
|
||||||
if_modified_since=None, follow_redirects=True,
|
if_modified_since=None, follow_redirects=True,
|
||||||
|
@ -513,6 +514,8 @@ class HTTPRequest(object):
|
||||||
network_interface=None, streaming_callback=None,
|
network_interface=None, streaming_callback=None,
|
||||||
header_callback=None, prepare_curl_callback=None,
|
header_callback=None, prepare_curl_callback=None,
|
||||||
allow_nonstandard_methods=False):
|
allow_nonstandard_methods=False):
|
||||||
|
if headers is None:
|
||||||
|
headers = httputil.HTTPHeaders()
|
||||||
if if_modified_since:
|
if if_modified_since:
|
||||||
timestamp = calendar.timegm(if_modified_since.utctimetuple())
|
timestamp = calendar.timegm(if_modified_since.utctimetuple())
|
||||||
headers["If-Modified-Since"] = email.utils.formatdate(
|
headers["If-Modified-Since"] = email.utils.formatdate(
|
||||||
|
@ -618,8 +621,13 @@ def _curl_create(max_simultaneous_connections=None):
|
||||||
|
|
||||||
def _curl_setup_request(curl, request, buffer, headers):
|
def _curl_setup_request(curl, request, buffer, headers):
|
||||||
curl.setopt(pycurl.URL, request.url)
|
curl.setopt(pycurl.URL, request.url)
|
||||||
curl.setopt(pycurl.HTTPHEADER,
|
# Request headers may be either a regular dict or HTTPHeaders object
|
||||||
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
|
if isinstance(request.headers, httputil.HTTPHeaders):
|
||||||
|
curl.setopt(pycurl.HTTPHEADER,
|
||||||
|
[_utf8("%s: %s" % i) for i in request.headers.get_all()])
|
||||||
|
else:
|
||||||
|
curl.setopt(pycurl.HTTPHEADER,
|
||||||
|
[_utf8("%s: %s" % i) for i in request.headers.iteritems()])
|
||||||
if request.header_callback:
|
if request.header_callback:
|
||||||
curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
|
curl.setopt(pycurl.HEADERFUNCTION, request.header_callback)
|
||||||
else:
|
else:
|
||||||
|
@ -695,17 +703,7 @@ def _curl_header_callback(headers, header_line):
|
||||||
return
|
return
|
||||||
if header_line == "\r\n":
|
if header_line == "\r\n":
|
||||||
return
|
return
|
||||||
parts = header_line.split(":", 1)
|
headers.parse_line(header_line)
|
||||||
if len(parts) != 2:
|
|
||||||
logging.warning("Invalid HTTP response header line %r", header_line)
|
|
||||||
return
|
|
||||||
name = parts[0].strip()
|
|
||||||
value = parts[1].strip()
|
|
||||||
if name in headers:
|
|
||||||
headers[name] = headers[name] + ',' + value
|
|
||||||
else:
|
|
||||||
headers[name] = value
|
|
||||||
|
|
||||||
|
|
||||||
def _curl_debug(debug_type, debug_msg):
|
def _curl_debug(debug_type, debug_msg):
|
||||||
debug_types = ('I', '<', '>', '<', '>')
|
debug_types = ('I', '<', '>', '<', '>')
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
import cgi
|
import cgi
|
||||||
import errno
|
import errno
|
||||||
import functools
|
import functools
|
||||||
|
import httputil
|
||||||
import ioloop
|
import ioloop
|
||||||
import iostream
|
import iostream
|
||||||
import logging
|
import logging
|
||||||
|
@ -277,7 +278,7 @@ class HTTPConnection(object):
|
||||||
method, uri, version = start_line.split(" ")
|
method, uri, version = start_line.split(" ")
|
||||||
if not version.startswith("HTTP/"):
|
if not version.startswith("HTTP/"):
|
||||||
raise Exception("Malformed HTTP version in HTTP Request-Line")
|
raise Exception("Malformed HTTP version in HTTP Request-Line")
|
||||||
headers = HTTPHeaders.parse(data[eol:])
|
headers = httputil.HTTPHeaders.parse(data[eol:])
|
||||||
self._request = HTTPRequest(
|
self._request = HTTPRequest(
|
||||||
connection=self, method=method, uri=uri, version=version,
|
connection=self, method=method, uri=uri, version=version,
|
||||||
headers=headers, remote_ip=self.address[0])
|
headers=headers, remote_ip=self.address[0])
|
||||||
|
@ -332,7 +333,7 @@ class HTTPConnection(object):
|
||||||
if eoh == -1:
|
if eoh == -1:
|
||||||
logging.warning("multipart/form-data missing headers")
|
logging.warning("multipart/form-data missing headers")
|
||||||
continue
|
continue
|
||||||
headers = HTTPHeaders.parse(part[:eoh])
|
headers = httputil.HTTPHeaders.parse(part[:eoh])
|
||||||
name_header = headers.get("Content-Disposition", "")
|
name_header = headers.get("Content-Disposition", "")
|
||||||
if not name_header.startswith("form-data;") or \
|
if not name_header.startswith("form-data;") or \
|
||||||
not part.endswith("\r\n"):
|
not part.endswith("\r\n"):
|
||||||
|
@ -380,7 +381,7 @@ class HTTPRequest(object):
|
||||||
self.method = method
|
self.method = method
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
self.version = version
|
self.version = version
|
||||||
self.headers = headers or HTTPHeaders()
|
self.headers = headers or httputil.HTTPHeaders()
|
||||||
self.body = body or ""
|
self.body = body or ""
|
||||||
if connection and connection.xheaders:
|
if connection and connection.xheaders:
|
||||||
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
||||||
|
@ -437,23 +438,3 @@ class HTTPRequest(object):
|
||||||
return "%s(%s, headers=%s)" % (
|
return "%s(%s, headers=%s)" % (
|
||||||
self.__class__.__name__, args, dict(self.headers))
|
self.__class__.__name__, args, dict(self.headers))
|
||||||
|
|
||||||
|
|
||||||
class HTTPHeaders(dict):
|
|
||||||
"""A dictionary that maintains Http-Header-Case for all keys."""
|
|
||||||
def __setitem__(self, name, value):
|
|
||||||
dict.__setitem__(self, self._normalize_name(name), value)
|
|
||||||
|
|
||||||
def __getitem__(self, name):
|
|
||||||
return dict.__getitem__(self, self._normalize_name(name))
|
|
||||||
|
|
||||||
def _normalize_name(self, name):
|
|
||||||
return "-".join([w.capitalize() for w in name.split("-")])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse(cls, headers_string):
|
|
||||||
headers = cls()
|
|
||||||
for line in headers_string.splitlines():
|
|
||||||
if line:
|
|
||||||
name, value = line.split(":", 1)
|
|
||||||
headers[name] = value.strip()
|
|
||||||
return headers
|
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# Copyright 2009 Facebook
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""HTTP utility code shared by clients and servers."""
|
||||||
|
|
||||||
|
class HTTPHeaders(dict):
|
||||||
|
"""A dictionary that maintains Http-Header-Case for all keys.
|
||||||
|
|
||||||
|
Supports multiple values per key via a pair of new methods,
|
||||||
|
add() and get_list(). The regular dictionary interface returns a single
|
||||||
|
value per key, with multiple values joined by a comma.
|
||||||
|
|
||||||
|
>>> h = HTTPHeaders({"content-type": "text/html"})
|
||||||
|
>>> h.keys()
|
||||||
|
['Content-Type']
|
||||||
|
>>> h["Content-Type"]
|
||||||
|
'text/html'
|
||||||
|
|
||||||
|
>>> h.add("Set-Cookie", "A=B")
|
||||||
|
>>> h.add("Set-Cookie", "C=D")
|
||||||
|
>>> h["set-cookie"]
|
||||||
|
'A=B,C=D'
|
||||||
|
>>> h.get_list("set-cookie")
|
||||||
|
['A=B', 'C=D']
|
||||||
|
|
||||||
|
>>> for (k,v) in sorted(h.get_all()):
|
||||||
|
... print '%s: %s' % (k,v)
|
||||||
|
...
|
||||||
|
Content-Type: text/html
|
||||||
|
Set-Cookie: A=B
|
||||||
|
Set-Cookie: C=D
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Don't pass args or kwargs to dict.__init__, as it will bypass
|
||||||
|
# our __setitem__
|
||||||
|
dict.__init__(self)
|
||||||
|
self._as_list = {}
|
||||||
|
self.update(*args, **kwargs)
|
||||||
|
|
||||||
|
# new public methods
|
||||||
|
|
||||||
|
def add(self, name, value):
|
||||||
|
"""Adds a new value for the given key."""
|
||||||
|
norm_name = HTTPHeaders._normalize_name(name)
|
||||||
|
if norm_name in self:
|
||||||
|
# bypass our override of __setitem__ since it modifies _as_list
|
||||||
|
dict.__setitem__(self, norm_name, self[norm_name] + ',' + value)
|
||||||
|
self._as_list[norm_name].append(value)
|
||||||
|
else:
|
||||||
|
self[norm_name] = value
|
||||||
|
|
||||||
|
def get_list(self, name):
|
||||||
|
"""Returns all values for the given header as a list."""
|
||||||
|
norm_name = HTTPHeaders._normalize_name(name)
|
||||||
|
return self._as_list.get(norm_name, [])
|
||||||
|
|
||||||
|
def get_all(self):
|
||||||
|
"""Returns an iterable of all (name, value) pairs.
|
||||||
|
|
||||||
|
If a header has multiple values, multiple pairs will be
|
||||||
|
returned with the same name.
|
||||||
|
"""
|
||||||
|
for name, list in self._as_list.iteritems():
|
||||||
|
for value in list:
|
||||||
|
yield (name, value)
|
||||||
|
|
||||||
|
def parse_line(self, line):
|
||||||
|
"""Updates the dictionary with a single header line.
|
||||||
|
|
||||||
|
>>> h = HTTPHeaders()
|
||||||
|
>>> h.parse_line("Content-Type: text/html")
|
||||||
|
>>> h.get('content-type')
|
||||||
|
'text/html'
|
||||||
|
"""
|
||||||
|
name, value = line.split(":", 1)
|
||||||
|
self.add(name, value.strip())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, headers):
|
||||||
|
"""Returns a dictionary from HTTP header text.
|
||||||
|
|
||||||
|
>>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n")
|
||||||
|
>>> sorted(h.iteritems())
|
||||||
|
[('Content-Length', '42'), ('Content-Type', 'text/html')]
|
||||||
|
"""
|
||||||
|
h = cls()
|
||||||
|
for line in headers.splitlines():
|
||||||
|
if line:
|
||||||
|
h.parse_line(line)
|
||||||
|
return h
|
||||||
|
|
||||||
|
# dict implementation overrides
|
||||||
|
|
||||||
|
def __setitem__(self, name, value):
|
||||||
|
norm_name = HTTPHeaders._normalize_name(name)
|
||||||
|
dict.__setitem__(self, norm_name, value)
|
||||||
|
self._as_list[norm_name] = [value]
|
||||||
|
|
||||||
|
def __getitem__(self, name):
|
||||||
|
return dict.__getitem__(self, HTTPHeaders._normalize_name(name))
|
||||||
|
|
||||||
|
def __delitem__(self, name):
|
||||||
|
norm_name = HTTPHeaders._normalize_name(name)
|
||||||
|
dict.__delitem__(self, norm_name)
|
||||||
|
del self._as_list[norm_name]
|
||||||
|
|
||||||
|
def get(self, name, default=None):
|
||||||
|
return dict.get(self, HTTPHeaders._normalize_name(name), default)
|
||||||
|
|
||||||
|
def update(self, *args, **kwargs):
|
||||||
|
# dict.update bypasses our __setitem__
|
||||||
|
for k, v in dict(*args, **kwargs).iteritems():
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_name(name):
|
||||||
|
"""Converts a name to Http-Header-Case.
|
||||||
|
|
||||||
|
>>> HTTPHeaders._normalize_name("coNtent-TYPE")
|
||||||
|
'Content-Type'
|
||||||
|
"""
|
||||||
|
return "-".join([w.capitalize() for w in name.split("-")])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import doctest
|
||||||
|
doctest.testmod()
|
|
@ -54,6 +54,7 @@ import cgi
|
||||||
import cStringIO
|
import cStringIO
|
||||||
import escape
|
import escape
|
||||||
import httplib
|
import httplib
|
||||||
|
import httputil
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
@ -100,7 +101,7 @@ class HTTPRequest(object):
|
||||||
values = [v for v in values if v]
|
values = [v for v in values if v]
|
||||||
if values: self.arguments[name] = values
|
if values: self.arguments[name] = values
|
||||||
self.version = "HTTP/1.1"
|
self.version = "HTTP/1.1"
|
||||||
self.headers = HTTPHeaders()
|
self.headers = httputil.HTTPHeaders()
|
||||||
if environ.get("CONTENT_TYPE"):
|
if environ.get("CONTENT_TYPE"):
|
||||||
self.headers["Content-Type"] = environ["CONTENT_TYPE"]
|
self.headers["Content-Type"] = environ["CONTENT_TYPE"]
|
||||||
if environ.get("CONTENT_LENGTH"):
|
if environ.get("CONTENT_LENGTH"):
|
||||||
|
@ -164,7 +165,7 @@ class HTTPRequest(object):
|
||||||
if eoh == -1:
|
if eoh == -1:
|
||||||
logging.warning("multipart/form-data missing headers")
|
logging.warning("multipart/form-data missing headers")
|
||||||
continue
|
continue
|
||||||
headers = HTTPHeaders.parse(part[:eoh])
|
headers = httputil.HTTPHeaders.parse(part[:eoh])
|
||||||
name_header = headers.get("Content-Disposition", "")
|
name_header = headers.get("Content-Disposition", "")
|
||||||
if not name_header.startswith("form-data;") or \
|
if not name_header.startswith("form-data;") or \
|
||||||
not part.endswith("\r\n"):
|
not part.endswith("\r\n"):
|
||||||
|
@ -293,23 +294,3 @@ class WSGIContainer(object):
|
||||||
request.remote_ip + ")"
|
request.remote_ip + ")"
|
||||||
log_method("%d %s %.2fms", status_code, summary, request_time)
|
log_method("%d %s %.2fms", status_code, summary, request_time)
|
||||||
|
|
||||||
|
|
||||||
class HTTPHeaders(dict):
|
|
||||||
"""A dictionary that maintains Http-Header-Case for all keys."""
|
|
||||||
def __setitem__(self, name, value):
|
|
||||||
dict.__setitem__(self, self._normalize_name(name), value)
|
|
||||||
|
|
||||||
def __getitem__(self, name):
|
|
||||||
return dict.__getitem__(self, self._normalize_name(name))
|
|
||||||
|
|
||||||
def _normalize_name(self, name):
|
|
||||||
return "-".join([w.capitalize() for w in name.split("-")])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def parse(cls, headers_string):
|
|
||||||
headers = cls()
|
|
||||||
for line in headers_string.splitlines():
|
|
||||||
if line:
|
|
||||||
name, value = line.split(":", 1)
|
|
||||||
headers[name] = value.strip()
|
|
||||||
return headers
|
|
||||||
|
|
Loading…
Reference in New Issue