From 344fb21fcfb81e400be2ae9fdd460a44260e6b05 Mon Sep 17 00:00:00 2001
From: Ben Darnell
Date: Thu, 24 Feb 2011 14:36:32 -0800
Subject: [PATCH] Mark many of our internal string literals as byte literals.

Use a function wrapper b() instead of real byte literals so we can
continue to support 2.5.
---
 tornado/httpserver.py                  |  9 +++--
 tornado/iostream.py                    |  6 ++-
 tornado/simple_httpclient.py           | 37 +++++++++---------
 tornado/test/iostream_test.py          |  9 +++--
 tornado/test/simple_httpclient_test.py | 27 ++++++-------
 tornado/test/stack_context_test.py     |  3 +-
 tornado/test/web_test.py               | 21 +++++-----
 tornado/util.py                        | 15 ++++++++
 tornado/web.py                         | 53 ++++++++++++++------------
 9 files changed, 105 insertions(+), 75 deletions(-)

diff --git a/tornado/httpserver.py b/tornado/httpserver.py
index a9c73973..c3a6f8a3 100644
--- a/tornado/httpserver.py
+++ b/tornado/httpserver.py
@@ -28,6 +28,7 @@ from tornado import httputil
 from tornado import ioloop
 from tornado import iostream
 from tornado import stack_context
+from tornado.util import b, bytes_type
 
 try:
     import fcntl
@@ -289,7 +290,7 @@ class HTTPConnection(object):
         # Save stack context here, outside of any request. This keeps
         # contexts from one request from leaking into the next.
         self._header_callback = stack_context.wrap(self._on_headers)
-        self.stream.read_until("\r\n\r\n", self._header_callback)
+        self.stream.read_until(b("\r\n\r\n"), self._header_callback)
 
     def write(self, chunk):
         assert self._request, "Request closed"
@@ -323,11 +324,11 @@
             if disconnect:
                 self.stream.close()
                 return
-        self.stream.read_until("\r\n\r\n", self._header_callback)
+        self.stream.read_until(b("\r\n\r\n"), self._header_callback)
 
     def _on_headers(self, data):
         try:
-            eol = data.find("\r\n")
+            eol = data.find(b("\r\n"))
             start_line = data[:eol]
             try:
                 method, uri, version = start_line.split(" ")
@@ -487,7 +488,7 @@ class HTTPRequest(object):
 
     def write(self, chunk):
         """Writes the given chunk to the response stream."""
-        assert isinstance(chunk, str)
+        assert isinstance(chunk, bytes_type)
         self.connection.write(chunk)
 
     def finish(self):
diff --git a/tornado/iostream.py b/tornado/iostream.py
index 65f36edd..aab96780 100644
--- a/tornado/iostream.py
+++ b/tornado/iostream.py
@@ -26,6 +26,7 @@ import sys
 
 from tornado import ioloop
 from tornado import stack_context
+from tornado.util import b, bytes_type
 
 try:
     import ssl  # Python 2.6+
@@ -141,7 +142,7 @@ class IOStream(object):
         """Call callback when we read the given number of bytes."""
         assert not self._read_callback, "Already reading"
         if num_bytes == 0:
-            callback("")
+            callback(b(""))
             return
         self._read_bytes = num_bytes
         self._read_callback = stack_context.wrap(callback)
@@ -161,6 +162,7 @@ class IOStream(object):
         previously buffered write data and an old write callback, that
         callback is simply overwritten with this new callback.
""" + assert isinstance(data, bytes_type) self._check_closed() self._write_buffer.append(data) self._add_io_state(self.io_loop.WRITE) @@ -517,7 +519,7 @@ def _merge_prefix(deque, size): chunk = chunk[:remaining] prefix.append(chunk) remaining -= len(chunk) - deque.appendleft(''.join(prefix)) + deque.appendleft(b('').join(prefix)) def doctests(): import doctest diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 38b5bc31..03e26e4e 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -2,12 +2,15 @@ from __future__ import with_statement from cStringIO import StringIO +from tornado.escape import utf8 from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient from tornado.httputil import HTTPHeaders from tornado.ioloop import IOLoop from tornado.iostream import IOStream, SSLIOStream from tornado import stack_context +from tornado.util import b +import base64 import collections import contextlib import copy @@ -190,17 +193,17 @@ class _HTTPConnection(object): if "Host" not in self.request.headers: self.request.headers["Host"] = parsed.netloc if self.request.auth_username: - auth = "%s:%s" % (self.request.auth_username, - self.request.auth_password) - self.request.headers["Authorization"] = ("Basic %s" % - auth.encode("base64")) + auth = utf8(self.request.auth_username) + b(":") + \ + utf8(self.request.auth_password) + self.request.headers["Authorization"] = \ + b("Basic ") + base64.b64encode(auth) if self.request.user_agent: self.request.headers["User-Agent"] = self.request.user_agent has_body = self.request.method in ("POST", "PUT") if has_body: assert self.request.body is not None - self.request.headers["Content-Length"] = len( - self.request.body) + self.request.headers["Content-Length"] = str(len( + self.request.body)) else: assert self.request.body is None if (self.request.method == "POST" and @@ -210,14 +213,14 @@ class _HTTPConnection(object): self.request.headers["Accept-Encoding"] = "gzip" req_path = ((parsed.path or '/') + (('?' 
+ parsed.query) if parsed.query else '')) - request_lines = ["%s %s HTTP/1.1" % (self.request.method, - req_path)] + request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method, + req_path))] for k, v in self.request.headers.get_all(): - request_lines.append("%s: %s" % (k, v)) - self.stream.write("\r\n".join(request_lines) + "\r\n\r\n") + request_lines.append(utf8(k) + b(": ") + utf8(v)) + self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n")) if has_body: - self.stream.write(self.request.body) - self.stream.read_until("\r\n\r\n", self._on_headers) + self.stream.write(utf8(self.request.body)) + self.stream.read_until(b("\r\n\r\n"), self._on_headers) @contextlib.contextmanager def cleanup(self): @@ -253,7 +256,7 @@ class _HTTPConnection(object): self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS) if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] - self.stream.read_until("\r\n", self._on_chunk_length) + self.stream.read_until(b("\r\n"), self._on_chunk_length) elif "Content-Length" in self.headers: self.stream.read_bytes(int(self.headers["Content-Length"]), self._on_body) @@ -282,7 +285,7 @@ class _HTTPConnection(object): self.code in (301, 302)): new_request = copy.copy(self.request) new_request.url = urlparse.urljoin(self.request.url, - self.headers["Location"]) + utf8(self.headers["Location"])) new_request.max_redirects -= 1 del new_request.headers["Host"] new_request.original_request = original_request @@ -303,13 +306,13 @@ class _HTTPConnection(object): # all the data has been decompressed, so we don't need to # decompress again in _on_body self._decompressor = None - self._on_body(''.join(self.chunks)) + self._on_body(b('').join(self.chunks)) else: self.stream.read_bytes(length + 2, # chunk ends with \r\n self._on_chunk_data) def _on_chunk_data(self, data): - assert data[-2:] == "\r\n" + assert data[-2:] == b("\r\n") chunk = data[:-2] if self._decompressor: chunk = self._decompressor.decompress(chunk) @@ -317,7 +320,7 @@ class _HTTPConnection(object): self.request.streaming_callback(chunk) else: self.chunks.append(chunk) - self.stream.read_until("\r\n", self._on_chunk_length) + self.stream.read_until(b("\r\n"), self._on_chunk_length) # match_hostname was added to the standard library ssl module in python 3.2. 
diff --git a/tornado/test/iostream_test.py b/tornado/test/iostream_test.py
index 48ccd4a6..835e19b7 100644
--- a/tornado/test/iostream_test.py
+++ b/tornado/test/iostream_test.py
@@ -1,5 +1,6 @@
 from tornado.iostream import IOStream
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
+from tornado.util import b
 from tornado.web import RequestHandler, Application
 
 import socket
@@ -15,22 +16,22 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         s.connect(("localhost", self.get_http_port()))
         self.stream = IOStream(s, io_loop=self.io_loop)
-        self.stream.write("GET / HTTP/1.0\r\n\r\n")
+        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
 
         # normal read
         self.stream.read_bytes(9, self.stop)
         data = self.wait()
-        self.assertEqual(data, "HTTP/1.0 ")
+        self.assertEqual(data, b("HTTP/1.0 "))
 
         # zero bytes
         self.stream.read_bytes(0, self.stop)
         data = self.wait()
-        self.assertEqual(data, "")
+        self.assertEqual(data, b(""))
 
         # another normal read
         self.stream.read_bytes(3, self.stop)
         data = self.wait()
-        self.assertEqual(data, "200")
+        self.assertEqual(data, b("200"))
 
     def test_connection_refused(self):
         # When a connection is refused, the connect callback should not
diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py
index 6db079c3..9bc9b42b 100644
--- a/tornado/test/simple_httpclient_test.py
+++ b/tornado/test/simple_httpclient_test.py
@@ -11,6 +11,7 @@ from contextlib import closing
 from tornado.ioloop import IOLoop
 from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
 from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
+from tornado.util import b
 from tornado.web import Application, RequestHandler, asynchronous, url
 
 class HelloWorldHandler(RequestHandler):
@@ -83,10 +84,10 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         response = self.fetch("/hello")
         self.assertEqual(response.code, 200)
         self.assertEqual(response.headers["Content-Type"], "text/plain")
-        self.assertEqual(response.body, "Hello world!")
+        self.assertEqual(response.body, b("Hello world!"))
 
         response = self.fetch("/hello?name=Ben")
-        self.assertEqual(response.body, "Hello Ben!")
+        self.assertEqual(response.body, b("Hello Ben!"))
 
     def test_streaming_callback(self):
         # streaming_callback is also tested in test_chunked
@@ -94,29 +95,29 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         response = self.fetch("/hello", streaming_callback=chunks.append)
         # with streaming_callback, data goes to the callback and not response.body
-        self.assertEqual(chunks, ["Hello world!"])
+        self.assertEqual(chunks, [b("Hello world!")])
         self.assertFalse(response.body)
 
     def test_post(self):
         response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar")
         self.assertEqual(response.code, 200)
-        self.assertEqual(response.body, "Post arg1: foo, arg2: bar")
+        self.assertEqual(response.body, b("Post arg1: foo, arg2: bar"))
 
     def test_chunked(self):
         response = self.fetch("/chunk")
-        self.assertEqual(response.body, "asdfqwer")
+        self.assertEqual(response.body, b("asdfqwer"))
 
         chunks = []
         response = self.fetch("/chunk", streaming_callback=chunks.append)
-        self.assertEqual(chunks, ["asdf", "qwer"])
+        self.assertEqual(chunks, [b("asdf"), b("qwer")])
         self.assertFalse(response.body)
 
     def test_basic_auth(self):
         self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
                                     auth_password="open sesame").body,
-                         "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
+                         b("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="))
 
     def test_gzip(self):
         # All the tests in this file should be using gzip, but this test
@@ -126,11 +127,11 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
         response = self.fetch("/chunk", use_gzip=False,
                               headers={"Accept-Encoding": "gzip"})
         self.assertEqual(response.headers["Content-Encoding"], "gzip")
-        self.assertNotEqual(response.body, "asdfqwer")
+        self.assertNotEqual(response.body, b("asdfqwer"))
         # Our test data gets bigger when gzipped. Oops. :)
         self.assertEqual(len(response.body), 34)
         f = gzip.GzipFile(mode="r", fileobj=response.buffer)
-        self.assertEqual(f.read(), "asdfqwer")
+        self.assertEqual(f.read(), b("asdfqwer"))
 
     def test_connect_timeout(self):
         # create a socket and bind it to a port, but don't
@@ -202,16 +203,16 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
 
         response = self.fetch("/countdown/2")
         self.assertEqual(200, response.code)
-        self.assertTrue(response.effective_url.endswith("/countdown/0"))
-        self.assertEqual("Zero", response.body)
+        self.assertTrue(response.effective_url.endswith(b("/countdown/0")))
+        self.assertEqual(b("Zero"), response.body)
 
     def test_max_redirects(self):
         response = self.fetch("/countdown/5", max_redirects=3)
         self.assertEqual(302, response.code)
         # We requested 5, followed three redirects for 4, 3, 2, then the last
         # unfollowed redirect is to 1.
-        self.assertTrue(response.request.url.endswith("/countdown/5"))
-        self.assertTrue(response.effective_url.endswith("/countdown/2"))
+        self.assertTrue(response.request.url.endswith(b("/countdown/5")))
+        self.assertTrue(response.effective_url.endswith(b("/countdown/2")))
         self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
 
     def test_default_certificates_exist(self):
diff --git a/tornado/test/stack_context_test.py b/tornado/test/stack_context_test.py
index f6e7421b..a9e20db3 100755
--- a/tornado/test/stack_context_test.py
+++ b/tornado/test/stack_context_test.py
@@ -3,6 +3,7 @@ from __future__ import with_statement
 
 from tornado.stack_context import StackContext, wrap
 from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
+from tornado.util import b
 from tornado.web import asynchronous, Application, RequestHandler
 import contextlib
 import functools
@@ -46,7 +47,7 @@ class HTTPStackContextTest(AsyncHTTPTestCase, LogTrapTestCase):
         self.http_client.fetch(self.get_url('/'), self.handle_response)
         self.wait()
         self.assertEquals(self.response.code, 500)
-        self.assertTrue('got expected exception' in self.response.body)
+        self.assertTrue(b('got expected exception') in self.response.body)
 
     def handle_response(self, response):
         self.response = response
diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py
index fccca7dc..8941332d 100644
--- a/tornado/test/web_test.py
+++ b/tornado/test/web_test.py
@@ -1,8 +1,10 @@
-from tornado.escape import json_decode
+from tornado.escape import json_decode, utf8
 from tornado.iostream import IOStream
 from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
+from tornado.util import b
 from tornado.web import RequestHandler, _O, authenticated, Application, asynchronous
+import binascii
 
 import logging
 import re
 import socket
@@ -24,15 +26,15 @@ class CookieTestRequestHandler(RequestHandler):
 
 class SecureCookieTest(LogTrapTestCase):
     def test_round_trip(self):
         handler = CookieTestRequestHandler()
-        handler.set_secure_cookie('foo', 'bar')
-        self.assertEquals(handler.get_secure_cookie('foo'), 'bar')
+        handler.set_secure_cookie('foo', b('bar'))
+        self.assertEquals(handler.get_secure_cookie('foo'), b('bar'))
 
     def test_cookie_tampering_future_timestamp(self):
         handler = CookieTestRequestHandler()
         # this string base64-encodes to '12345678'
-        handler.set_secure_cookie('foo', '\xd7m\xf8\xe7\xae\xfc')
+        handler.set_secure_cookie('foo', binascii.a2b_hex(b('d76df8e7aefc')))
         cookie = handler._cookies['foo']
-        match = re.match(r'12345678\|([0-9]+)\|([0-9a-f]+)', cookie)
+        match = re.match(b(r'12345678\|([0-9]+)\|([0-9a-f]+)'), cookie)
         assert match
         timestamp = match.group(1)
         sig = match.group(2)
@@ -41,10 +43,11 @@
         # shifting digits from payload to timestamp doesn't alter signature
         # (this is not desirable behavior, just confirming that that's how it
         # works)
-        self.assertEqual(handler._cookie_signature('foo', '1234',
-                                                   '5678' + timestamp), sig)
+        self.assertEqual(
+            handler._cookie_signature('foo', '1234', b('5678') + timestamp),
+            sig)
         # tamper with the cookie
-        handler._cookies['foo'] = '1234|5678%s|%s' % (timestamp, sig)
+        handler._cookies['foo'] = utf8('1234|5678%s|%s' % (timestamp, sig))
         # it gets rejected
         assert handler.get_secure_cookie('foo') is None
 
@@ -104,7 +107,7 @@ class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
         s.connect(("localhost", self.get_http_port()))
         self.stream = IOStream(s, io_loop=self.io_loop)
-        self.stream.write("GET / HTTP/1.0\r\n\r\n")
+        self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
         self.wait()
 
     def on_handler_waiting(self):
diff --git a/tornado/util.py b/tornado/util.py
index 3706b161..964a9f02 100644
--- a/tornado/util.py
+++ b/tornado/util.py
@@ -15,6 +15,21 @@ def import_object(name):
     obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
     return getattr(obj, parts[-1])
 
+# Fake byte literal support: In python 2.6+, you can say b"foo" to get
+# a byte literal (str in 2.x, bytes in 3.x). There's no way to do this
+# in a way that supports 2.5, though, so we need a function wrapper
+# to convert our string literals. b() should only be applied to literal
+# ascii strings. Once we drop support for 2.5, we can remove this function
+# and just use byte literals.
+if str is unicode:
+    def b(s):
+        return s.encode('ascii')
+    bytes_type = bytes
+else:
+    def b(s):
+        return s
+    bytes_type = str
+
 def doctests():
     import doctest
     return doctest.DocTestSuite()
diff --git a/tornado/web.py b/tornado/web.py
index f7812791..2a3b3ac9 100644
--- a/tornado/web.py
+++ b/tornado/web.py
@@ -82,6 +82,7 @@ from tornado import locale
 from tornado import stack_context
 from tornado import template
 from tornado.escape import utf8
+from tornado.util import b, bytes_type
 
 class RequestHandler(object):
     """Subclass this class and define get() or post() to make a handler.
@@ -211,7 +212,7 @@ class RequestHandler(object):
         # If \n is allowed into the header, it is possible to inject
         # additional headers or split the request. Also cap length to
         # prevent obviously erroneous values.
- safe_value = re.sub(r"[\x00-\x1f]", " ", value)[:4000] + safe_value = re.sub(b(r"[\x00-\x1f]"), b(" "), value)[:4000] if safe_value != value: raise ValueError("Unsafe header value %r", value) self._headers[name] = value @@ -244,8 +245,8 @@ class RequestHandler(object): """ values = self.request.arguments.get(name, []) # Get rid of any weird control chars - values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", x) for x in values] - values = [_unicode(x) for x in values] + values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", _unicode(x)) + for x in values] if strip: values = [x.strip() for x in values] return values @@ -332,10 +333,10 @@ class RequestHandler(object): method for non-cookie uses. To decode a value not stored as a cookie use the optional value argument to get_secure_cookie. """ - timestamp = str(int(time.time())) - value = base64.b64encode(value) + timestamp = utf8(str(int(time.time()))) + value = base64.b64encode(utf8(value)) signature = self._cookie_signature(name, value, timestamp) - value = "|".join([value, timestamp, signature]) + value = b("|").join([value, timestamp, signature]) return value def get_secure_cookie(self, name, include_name=True, value=None): @@ -350,7 +351,7 @@ class RequestHandler(object): """ if value is None: value = self.get_cookie(name) if not value: return None - parts = value.split("|") + parts = value.split(b("|")) if len(parts) != 3: return None if include_name: signature = self._cookie_signature(name, parts[0], parts[1]) @@ -371,7 +372,7 @@ class RequestHandler(object): # here instead of modifying _cookie_signature. logging.warning("Cookie timestamp in future; possible tampering %r", value) return None - if parts[1].startswith("0"): + if parts[1].startswith(b("0")): logging.warning("Tampered cookie %r", value) try: return base64.b64decode(parts[0]) @@ -380,10 +381,10 @@ class RequestHandler(object): def _cookie_signature(self, *parts): self.require_setting("cookie_secret", "secure cookies") - hash = hmac.new(self.application.settings["cookie_secret"], + hash = hmac.new(utf8(self.application.settings["cookie_secret"]), digestmod=hashlib.sha1) - for part in parts: hash.update(part) - return hash.hexdigest() + for part in parts: hash.update(utf8(part)) + return utf8(hash.hexdigest()) def redirect(self, url, permanent=False): """Sends a redirect to the given (optionally relative) URL.""" @@ -391,8 +392,9 @@ class RequestHandler(object): raise Exception("Cannot redirect after headers have been written") self.set_status(301 if permanent else 302) # Remove whitespace - url = re.sub(r"[\x00-\x20]+", "", utf8(url)) - self.set_header("Location", urlparse.urljoin(self.request.uri, url)) + url = re.sub(b(r"[\x00-\x20]+"), "", utf8(url)) + self.set_header("Location", urlparse.urljoin(utf8(self.request.uri), + url)) self.finish() def write(self, chunk): @@ -534,7 +536,7 @@ class RequestHandler(object): if self.application._wsgi: raise Exception("WSGI applications do not support flush()") - chunk = "".join(self._write_buffer) + chunk = b("").join(self._write_buffer) self._write_buffer = [] if not self._headers_written: self._headers_written = True @@ -545,7 +547,7 @@ class RequestHandler(object): else: for transform in self._transforms: chunk = transform.transform_chunk(chunk, include_footers) - headers = "" + headers = b("") # Ignore the chunk and only write the headers for HEAD requests if self.request.method == "HEAD": @@ -864,13 +866,14 @@ class RequestHandler(object): self.finish() def _generate_headers(self): - lines = [self.request.version + " " + 
str(self._status_code) + " " + - httplib.responses[self._status_code]] - lines.extend(["%s: %s" % (n, v) for n, v in self._headers.iteritems()]) + lines = [utf8(self.request.version + " " + + str(self._status_code) + + " " + httplib.responses[self._status_code])] + lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in self._headers.iteritems()]) for cookie_dict in getattr(self, "_new_cookies", []): for cookie in cookie_dict.values(): - lines.append("Set-Cookie: " + cookie.OutputString(None)) - return "\r\n".join(lines) + "\r\n\r\n" + lines.append(b("Set-Cookie: ") + cookie.OutputString(None)) + return b("\r\n").join(lines) + b("\r\n\r\n") def _log(self): """Logs the current request. @@ -882,8 +885,8 @@ class RequestHandler(object): self.application.log_request(self) def _request_summary(self): - return self.request.method + " " + self.request.uri + " (" + \ - self.request.remote_ip + ")" + return utf8(self.request.method) + b(" ") + utf8(self.request.uri) + \ + b(" (") + utf8(self.request.remote_ip) + b(")") def _handle_request_exception(self, e): if isinstance(e, HTTPError): @@ -1472,9 +1475,9 @@ class ChunkedTransferEncoding(OutputTransform): # Don't write out empty chunks because that means END-OF-STREAM # with chunked encoding if block: - block = ("%x" % len(block)) + "\r\n" + block + "\r\n" + block = utf8("%x" % len(block)) + b("\r\n") + block + b("\r\n") if finishing: - block += "0\r\n\r\n" + block += b("0\r\n\r\n") return block @@ -1607,7 +1610,7 @@ url = URLSpec def _unicode(s): - if isinstance(s, str): + if isinstance(s, bytes_type): try: return s.decode("utf-8") except UnicodeDecodeError: