From a7dc5bc4b463c90c41355c9d8d54e82fd39cee5f Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Thu, 8 Jul 2010 17:06:03 -0700 Subject: [PATCH] Consolidate the various HTTP header dictionary classes into one, which includes better handling of headers with repeated values (e.g. Set-Cookie) --- tornado/httpclient.py | 32 +++++----- tornado/httpserver.py | 27 ++------ tornado/httputil.py | 140 ++++++++++++++++++++++++++++++++++++++++++ tornado/wsgi.py | 25 +------- 4 files changed, 162 insertions(+), 62 deletions(-) create mode 100755 tornado/httputil.py diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 4d29da66..842950db 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -24,6 +24,7 @@ import errno import escape import functools import httplib +import httputil import ioloop import logging import pycurl @@ -60,7 +61,7 @@ class HTTPClient(object): if not isinstance(request, HTTPRequest): request = HTTPRequest(url=request, **kwargs) buffer = cStringIO.StringIO() - headers = {} + headers = httputil.HTTPHeaders() try: _curl_setup_request(self._curl, request, buffer, headers) self._curl.perform() @@ -254,7 +255,7 @@ class AsyncHTTPClient(object): curl = self._free_list.pop() (request, callback) = self._requests.popleft() curl.info = { - "headers": {}, + "headers": httputil.HTTPHeaders(), "buffer": cStringIO.StringIO(), "request": request, "callback": callback, @@ -462,7 +463,7 @@ class AsyncHTTPClient2(object): curl = self._free_list.pop() (request, callback) = self._requests.popleft() curl.info = { - "headers": {}, + "headers": httputil.HTTPHeaders(), "buffer": cStringIO.StringIO(), "request": request, "callback": callback, @@ -505,7 +506,7 @@ class AsyncHTTPClient2(object): class HTTPRequest(object): - def __init__(self, url, method="GET", headers={}, body=None, + def __init__(self, url, method="GET", headers=None, body=None, auth_username=None, auth_password=None, connect_timeout=20.0, request_timeout=20.0, if_modified_since=None, follow_redirects=True, @@ -513,6 +514,8 @@ class HTTPRequest(object): network_interface=None, streaming_callback=None, header_callback=None, prepare_curl_callback=None, allow_nonstandard_methods=False): + if headers is None: + headers = httputil.HTTPHeaders() if if_modified_since: timestamp = calendar.timegm(if_modified_since.utctimetuple()) headers["If-Modified-Since"] = email.utils.formatdate( @@ -618,8 +621,13 @@ def _curl_create(max_simultaneous_connections=None): def _curl_setup_request(curl, request, buffer, headers): curl.setopt(pycurl.URL, request.url) - curl.setopt(pycurl.HTTPHEADER, - [_utf8("%s: %s" % i) for i in request.headers.iteritems()]) + # Request headers may be either a regular dict or HTTPHeaders object + if isinstance(request.headers, httputil.HTTPHeaders): + curl.setopt(pycurl.HTTPHEADER, + [_utf8("%s: %s" % i) for i in request.headers.get_all()]) + else: + curl.setopt(pycurl.HTTPHEADER, + [_utf8("%s: %s" % i) for i in request.headers.iteritems()]) if request.header_callback: curl.setopt(pycurl.HEADERFUNCTION, request.header_callback) else: @@ -695,17 +703,7 @@ def _curl_header_callback(headers, header_line): return if header_line == "\r\n": return - parts = header_line.split(":", 1) - if len(parts) != 2: - logging.warning("Invalid HTTP response header line %r", header_line) - return - name = parts[0].strip() - value = parts[1].strip() - if name in headers: - headers[name] = headers[name] + ',' + value - else: - headers[name] = value - + headers.parse_line(header_line) def _curl_debug(debug_type, debug_msg): debug_types = ('I', '<', '>', '<', '>') diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 63131070..ad7ab077 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -19,6 +19,7 @@ import cgi import errno import functools +import httputil import ioloop import iostream import logging @@ -277,7 +278,7 @@ class HTTPConnection(object): method, uri, version = start_line.split(" ") if not version.startswith("HTTP/"): raise Exception("Malformed HTTP version in HTTP Request-Line") - headers = HTTPHeaders.parse(data[eol:]) + headers = httputil.HTTPHeaders.parse(data[eol:]) self._request = HTTPRequest( connection=self, method=method, uri=uri, version=version, headers=headers, remote_ip=self.address[0]) @@ -332,7 +333,7 @@ class HTTPConnection(object): if eoh == -1: logging.warning("multipart/form-data missing headers") continue - headers = HTTPHeaders.parse(part[:eoh]) + headers = httputil.HTTPHeaders.parse(part[:eoh]) name_header = headers.get("Content-Disposition", "") if not name_header.startswith("form-data;") or \ not part.endswith("\r\n"): @@ -380,7 +381,7 @@ class HTTPRequest(object): self.method = method self.uri = uri self.version = version - self.headers = headers or HTTPHeaders() + self.headers = headers or httputil.HTTPHeaders() self.body = body or "" if connection and connection.xheaders: # Squid uses X-Forwarded-For, others use X-Real-Ip @@ -437,23 +438,3 @@ class HTTPRequest(object): return "%s(%s, headers=%s)" % ( self.__class__.__name__, args, dict(self.headers)) - -class HTTPHeaders(dict): - """A dictionary that maintains Http-Header-Case for all keys.""" - def __setitem__(self, name, value): - dict.__setitem__(self, self._normalize_name(name), value) - - def __getitem__(self, name): - return dict.__getitem__(self, self._normalize_name(name)) - - def _normalize_name(self, name): - return "-".join([w.capitalize() for w in name.split("-")]) - - @classmethod - def parse(cls, headers_string): - headers = cls() - for line in headers_string.splitlines(): - if line: - name, value = line.split(":", 1) - headers[name] = value.strip() - return headers diff --git a/tornado/httputil.py b/tornado/httputil.py new file mode 100755 index 00000000..5e563e88 --- /dev/null +++ b/tornado/httputil.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# +# Copyright 2009 Facebook +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""HTTP utility code shared by clients and servers.""" + +class HTTPHeaders(dict): + """A dictionary that maintains Http-Header-Case for all keys. + + Supports multiple values per key via a pair of new methods, + add() and get_list(). The regular dictionary interface returns a single + value per key, with multiple values joined by a comma. + + >>> h = HTTPHeaders({"content-type": "text/html"}) + >>> h.keys() + ['Content-Type'] + >>> h["Content-Type"] + 'text/html' + + >>> h.add("Set-Cookie", "A=B") + >>> h.add("Set-Cookie", "C=D") + >>> h["set-cookie"] + 'A=B,C=D' + >>> h.get_list("set-cookie") + ['A=B', 'C=D'] + + >>> for (k,v) in sorted(h.get_all()): + ... print '%s: %s' % (k,v) + ... + Content-Type: text/html + Set-Cookie: A=B + Set-Cookie: C=D + """ + def __init__(self, *args, **kwargs): + # Don't pass args or kwargs to dict.__init__, as it will bypass + # our __setitem__ + dict.__init__(self) + self._as_list = {} + self.update(*args, **kwargs) + + # new public methods + + def add(self, name, value): + """Adds a new value for the given key.""" + norm_name = HTTPHeaders._normalize_name(name) + if norm_name in self: + # bypass our override of __setitem__ since it modifies _as_list + dict.__setitem__(self, norm_name, self[norm_name] + ',' + value) + self._as_list[norm_name].append(value) + else: + self[norm_name] = value + + def get_list(self, name): + """Returns all values for the given header as a list.""" + norm_name = HTTPHeaders._normalize_name(name) + return self._as_list.get(norm_name, []) + + def get_all(self): + """Returns an iterable of all (name, value) pairs. + + If a header has multiple values, multiple pairs will be + returned with the same name. + """ + for name, list in self._as_list.iteritems(): + for value in list: + yield (name, value) + + def parse_line(self, line): + """Updates the dictionary with a single header line. + + >>> h = HTTPHeaders() + >>> h.parse_line("Content-Type: text/html") + >>> h.get('content-type') + 'text/html' + """ + name, value = line.split(":", 1) + self.add(name, value.strip()) + + @classmethod + def parse(cls, headers): + """Returns a dictionary from HTTP header text. + + >>> h = HTTPHeaders.parse("Content-Type: text/html\\r\\nContent-Length: 42\\r\\n") + >>> sorted(h.iteritems()) + [('Content-Length', '42'), ('Content-Type', 'text/html')] + """ + h = cls() + for line in headers.splitlines(): + if line: + h.parse_line(line) + return h + + # dict implementation overrides + + def __setitem__(self, name, value): + norm_name = HTTPHeaders._normalize_name(name) + dict.__setitem__(self, norm_name, value) + self._as_list[norm_name] = [value] + + def __getitem__(self, name): + return dict.__getitem__(self, HTTPHeaders._normalize_name(name)) + + def __delitem__(self, name): + norm_name = HTTPHeaders._normalize_name(name) + dict.__delitem__(self, norm_name) + del self._as_list[norm_name] + + def get(self, name, default=None): + return dict.get(self, HTTPHeaders._normalize_name(name), default) + + def update(self, *args, **kwargs): + # dict.update bypasses our __setitem__ + for k, v in dict(*args, **kwargs).iteritems(): + self[k] = v + + @staticmethod + def _normalize_name(name): + """Converts a name to Http-Header-Case. + + >>> HTTPHeaders._normalize_name("coNtent-TYPE") + 'Content-Type' + """ + return "-".join([w.capitalize() for w in name.split("-")]) + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/tornado/wsgi.py b/tornado/wsgi.py index 4aaa5fb7..de356696 100644 --- a/tornado/wsgi.py +++ b/tornado/wsgi.py @@ -54,6 +54,7 @@ import cgi import cStringIO import escape import httplib +import httputil import logging import sys import time @@ -100,7 +101,7 @@ class HTTPRequest(object): values = [v for v in values if v] if values: self.arguments[name] = values self.version = "HTTP/1.1" - self.headers = HTTPHeaders() + self.headers = httputil.HTTPHeaders() if environ.get("CONTENT_TYPE"): self.headers["Content-Type"] = environ["CONTENT_TYPE"] if environ.get("CONTENT_LENGTH"): @@ -164,7 +165,7 @@ class HTTPRequest(object): if eoh == -1: logging.warning("multipart/form-data missing headers") continue - headers = HTTPHeaders.parse(part[:eoh]) + headers = httputil.HTTPHeaders.parse(part[:eoh]) name_header = headers.get("Content-Disposition", "") if not name_header.startswith("form-data;") or \ not part.endswith("\r\n"): @@ -293,23 +294,3 @@ class WSGIContainer(object): request.remote_ip + ")" log_method("%d %s %.2fms", status_code, summary, request_time) - -class HTTPHeaders(dict): - """A dictionary that maintains Http-Header-Case for all keys.""" - def __setitem__(self, name, value): - dict.__setitem__(self, self._normalize_name(name), value) - - def __getitem__(self, name): - return dict.__getitem__(self, self._normalize_name(name)) - - def _normalize_name(self, name): - return "-".join([w.capitalize() for w in name.split("-")]) - - @classmethod - def parse(cls, headers_string): - headers = cls() - for line in headers_string.splitlines(): - if line: - name, value = line.split(":", 1) - headers[name] = value.strip() - return headers