Mark many of our internal string literals as byte literals.

Use a function wrapper b() instead of real byte literals so we
can continue to support 2.5.
This commit is contained in:
Ben Darnell 2011-02-24 14:36:32 -08:00
parent a3ee8cf69f
commit 344fb21fcf
9 changed files with 105 additions and 75 deletions

View File

@ -28,6 +28,7 @@ from tornado import httputil
from tornado import ioloop
from tornado import iostream
from tornado import stack_context
from tornado.util import b, bytes_type
try:
import fcntl
@ -289,7 +290,7 @@ class HTTPConnection(object):
# Save stack context here, outside of any request. This keeps
# contexts from one request from leaking into the next.
self._header_callback = stack_context.wrap(self._on_headers)
self.stream.read_until("\r\n\r\n", self._header_callback)
self.stream.read_until(b("\r\n\r\n"), self._header_callback)
def write(self, chunk):
assert self._request, "Request closed"
@ -323,11 +324,11 @@ class HTTPConnection(object):
if disconnect:
self.stream.close()
return
self.stream.read_until("\r\n\r\n", self._header_callback)
self.stream.read_until(b("\r\n\r\n"), self._header_callback)
def _on_headers(self, data):
try:
eol = data.find("\r\n")
eol = data.find(b("\r\n"))
start_line = data[:eol]
try:
method, uri, version = start_line.split(" ")
@ -487,7 +488,7 @@ class HTTPRequest(object):
def write(self, chunk):
"""Writes the given chunk to the response stream."""
assert isinstance(chunk, str)
assert isinstance(chunk, bytes_type)
self.connection.write(chunk)
def finish(self):

View File

@ -26,6 +26,7 @@ import sys
from tornado import ioloop
from tornado import stack_context
from tornado.util import b, bytes_type
try:
import ssl # Python 2.6+
@ -141,7 +142,7 @@ class IOStream(object):
"""Call callback when we read the given number of bytes."""
assert not self._read_callback, "Already reading"
if num_bytes == 0:
callback("")
callback(b(""))
return
self._read_bytes = num_bytes
self._read_callback = stack_context.wrap(callback)
@ -161,6 +162,7 @@ class IOStream(object):
previously buffered write data and an old write callback, that
callback is simply overwritten with this new callback.
"""
assert isinstance(data, bytes_type)
self._check_closed()
self._write_buffer.append(data)
self._add_io_state(self.io_loop.WRITE)
@ -517,7 +519,7 @@ def _merge_prefix(deque, size):
chunk = chunk[:remaining]
prefix.append(chunk)
remaining -= len(chunk)
deque.appendleft(''.join(prefix))
deque.appendleft(b('').join(prefix))
def doctests():
import doctest

View File

@ -2,12 +2,15 @@
from __future__ import with_statement
from cStringIO import StringIO
from tornado.escape import utf8
from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
from tornado import stack_context
from tornado.util import b
import base64
import collections
import contextlib
import copy
@ -190,17 +193,17 @@ class _HTTPConnection(object):
if "Host" not in self.request.headers:
self.request.headers["Host"] = parsed.netloc
if self.request.auth_username:
auth = "%s:%s" % (self.request.auth_username,
self.request.auth_password)
self.request.headers["Authorization"] = ("Basic %s" %
auth.encode("base64"))
auth = utf8(self.request.auth_username) + b(":") + \
utf8(self.request.auth_password)
self.request.headers["Authorization"] = \
b("Basic ") + base64.b64encode(auth)
if self.request.user_agent:
self.request.headers["User-Agent"] = self.request.user_agent
has_body = self.request.method in ("POST", "PUT")
if has_body:
assert self.request.body is not None
self.request.headers["Content-Length"] = len(
self.request.body)
self.request.headers["Content-Length"] = str(len(
self.request.body))
else:
assert self.request.body is None
if (self.request.method == "POST" and
@ -210,14 +213,14 @@ class _HTTPConnection(object):
self.request.headers["Accept-Encoding"] = "gzip"
req_path = ((parsed.path or '/') +
(('?' + parsed.query) if parsed.query else ''))
request_lines = ["%s %s HTTP/1.1" % (self.request.method,
req_path)]
request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
req_path))]
for k, v in self.request.headers.get_all():
request_lines.append("%s: %s" % (k, v))
self.stream.write("\r\n".join(request_lines) + "\r\n\r\n")
request_lines.append(utf8(k) + b(": ") + utf8(v))
self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
if has_body:
self.stream.write(self.request.body)
self.stream.read_until("\r\n\r\n", self._on_headers)
self.stream.write(utf8(self.request.body))
self.stream.read_until(b("\r\n\r\n"), self._on_headers)
@contextlib.contextmanager
def cleanup(self):
@ -253,7 +256,7 @@ class _HTTPConnection(object):
self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
if self.headers.get("Transfer-Encoding") == "chunked":
self.chunks = []
self.stream.read_until("\r\n", self._on_chunk_length)
self.stream.read_until(b("\r\n"), self._on_chunk_length)
elif "Content-Length" in self.headers:
self.stream.read_bytes(int(self.headers["Content-Length"]),
self._on_body)
@ -282,7 +285,7 @@ class _HTTPConnection(object):
self.code in (301, 302)):
new_request = copy.copy(self.request)
new_request.url = urlparse.urljoin(self.request.url,
self.headers["Location"])
utf8(self.headers["Location"]))
new_request.max_redirects -= 1
del new_request.headers["Host"]
new_request.original_request = original_request
@ -303,13 +306,13 @@ class _HTTPConnection(object):
# all the data has been decompressed, so we don't need to
# decompress again in _on_body
self._decompressor = None
self._on_body(''.join(self.chunks))
self._on_body(b('').join(self.chunks))
else:
self.stream.read_bytes(length + 2, # chunk ends with \r\n
self._on_chunk_data)
def _on_chunk_data(self, data):
assert data[-2:] == "\r\n"
assert data[-2:] == b("\r\n")
chunk = data[:-2]
if self._decompressor:
chunk = self._decompressor.decompress(chunk)
@ -317,7 +320,7 @@ class _HTTPConnection(object):
self.request.streaming_callback(chunk)
else:
self.chunks.append(chunk)
self.stream.read_until("\r\n", self._on_chunk_length)
self.stream.read_until(b("\r\n"), self._on_chunk_length)
# match_hostname was added to the standard library ssl module in python 3.2.

View File

@ -1,5 +1,6 @@
from tornado.iostream import IOStream
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b
from tornado.web import RequestHandler, Application
import socket
@ -15,22 +16,22 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream.write("GET / HTTP/1.0\r\n\r\n")
self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
# normal read
self.stream.read_bytes(9, self.stop)
data = self.wait()
self.assertEqual(data, "HTTP/1.0 ")
self.assertEqual(data, b("HTTP/1.0 "))
# zero bytes
self.stream.read_bytes(0, self.stop)
data = self.wait()
self.assertEqual(data, "")
self.assertEqual(data, b(""))
# another normal read
self.stream.read_bytes(3, self.stop)
data = self.wait()
self.assertEqual(data, "200")
self.assertEqual(data, b("200"))
def test_connection_refused(self):
# When a connection is refused, the connect callback should not

View File

@ -11,6 +11,7 @@ from contextlib import closing
from tornado.ioloop import IOLoop
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
from tornado.util import b
from tornado.web import Application, RequestHandler, asynchronous, url
class HelloWorldHandler(RequestHandler):
@ -83,10 +84,10 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
response = self.fetch("/hello")
self.assertEqual(response.code, 200)
self.assertEqual(response.headers["Content-Type"], "text/plain")
self.assertEqual(response.body, "Hello world!")
self.assertEqual(response.body, b("Hello world!"))
response = self.fetch("/hello?name=Ben")
self.assertEqual(response.body, "Hello Ben!")
self.assertEqual(response.body, b("Hello Ben!"))
def test_streaming_callback(self):
# streaming_callback is also tested in test_chunked
@ -94,29 +95,29 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
response = self.fetch("/hello",
streaming_callback=chunks.append)
# with streaming_callback, data goes to the callback and not response.body
self.assertEqual(chunks, ["Hello world!"])
self.assertEqual(chunks, [b("Hello world!")])
self.assertFalse(response.body)
def test_post(self):
response = self.fetch("/post", method="POST",
body="arg1=foo&arg2=bar")
self.assertEqual(response.code, 200)
self.assertEqual(response.body, "Post arg1: foo, arg2: bar")
self.assertEqual(response.body, b("Post arg1: foo, arg2: bar"))
def test_chunked(self):
response = self.fetch("/chunk")
self.assertEqual(response.body, "asdfqwer")
self.assertEqual(response.body, b("asdfqwer"))
chunks = []
response = self.fetch("/chunk",
streaming_callback=chunks.append)
self.assertEqual(chunks, ["asdf", "qwer"])
self.assertEqual(chunks, [b("asdf"), b("qwer")])
self.assertFalse(response.body)
def test_basic_auth(self):
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
auth_password="open sesame").body,
"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
b("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="))
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
@ -126,11 +127,11 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
response = self.fetch("/chunk", use_gzip=False,
headers={"Accept-Encoding": "gzip"})
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, "asdfqwer")
self.assertNotEqual(response.body, b("asdfqwer"))
# Our test data gets bigger when gzipped. Oops. :)
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), "asdfqwer")
self.assertEqual(f.read(), b("asdfqwer"))
def test_connect_timeout(self):
# create a socket and bind it to a port, but don't
@ -202,16 +203,16 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
response = self.fetch("/countdown/2")
self.assertEqual(200, response.code)
self.assertTrue(response.effective_url.endswith("/countdown/0"))
self.assertEqual("Zero", response.body)
self.assertTrue(response.effective_url.endswith(b("/countdown/0")))
self.assertEqual(b("Zero"), response.body)
def test_max_redirects(self):
response = self.fetch("/countdown/5", max_redirects=3)
self.assertEqual(302, response.code)
# We requested 5, followed three redirects for 4, 3, 2, then the last
# unfollowed redirect is to 1.
self.assertTrue(response.request.url.endswith("/countdown/5"))
self.assertTrue(response.effective_url.endswith("/countdown/2"))
self.assertTrue(response.request.url.endswith(b("/countdown/5")))
self.assertTrue(response.effective_url.endswith(b("/countdown/2")))
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
def test_default_certificates_exist(self):

View File

@ -3,6 +3,7 @@ from __future__ import with_statement
from tornado.stack_context import StackContext, wrap
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
from tornado.util import b
from tornado.web import asynchronous, Application, RequestHandler
import contextlib
import functools
@ -46,7 +47,7 @@ class HTTPStackContextTest(AsyncHTTPTestCase, LogTrapTestCase):
self.http_client.fetch(self.get_url('/'), self.handle_response)
self.wait()
self.assertEquals(self.response.code, 500)
self.assertTrue('got expected exception' in self.response.body)
self.assertTrue(b('got expected exception') in self.response.body)
def handle_response(self, response):
self.response = response

View File

@ -1,8 +1,10 @@
from tornado.escape import json_decode
from tornado.escape import json_decode, utf8
from tornado.iostream import IOStream
from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
from tornado.util import b
from tornado.web import RequestHandler, _O, authenticated, Application, asynchronous
import binascii
import logging
import re
import socket
@ -24,15 +26,15 @@ class CookieTestRequestHandler(RequestHandler):
class SecureCookieTest(LogTrapTestCase):
def test_round_trip(self):
handler = CookieTestRequestHandler()
handler.set_secure_cookie('foo', 'bar')
self.assertEquals(handler.get_secure_cookie('foo'), 'bar')
handler.set_secure_cookie('foo', b('bar'))
self.assertEquals(handler.get_secure_cookie('foo'), b('bar'))
def test_cookie_tampering_future_timestamp(self):
handler = CookieTestRequestHandler()
# this string base64-encodes to '12345678'
handler.set_secure_cookie('foo', '\xd7m\xf8\xe7\xae\xfc')
handler.set_secure_cookie('foo', binascii.a2b_hex(b('d76df8e7aefc')))
cookie = handler._cookies['foo']
match = re.match(r'12345678\|([0-9]+)\|([0-9a-f]+)', cookie)
match = re.match(b(r'12345678\|([0-9]+)\|([0-9a-f]+)'), cookie)
assert match
timestamp = match.group(1)
sig = match.group(2)
@ -41,10 +43,11 @@ class SecureCookieTest(LogTrapTestCase):
# shifting digits from payload to timestamp doesn't alter signature
# (this is not desirable behavior, just confirming that that's how it
# works)
self.assertEqual(handler._cookie_signature('foo', '1234',
'5678' + timestamp), sig)
self.assertEqual(
handler._cookie_signature('foo', '1234', b('5678') + timestamp),
sig)
# tamper with the cookie
handler._cookies['foo'] = '1234|5678%s|%s' % (timestamp, sig)
handler._cookies['foo'] = utf8('1234|5678%s|%s' % (timestamp, sig))
# it gets rejected
assert handler.get_secure_cookie('foo') is None
@ -104,7 +107,7 @@ class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
s.connect(("localhost", self.get_http_port()))
self.stream = IOStream(s, io_loop=self.io_loop)
self.stream.write("GET / HTTP/1.0\r\n\r\n")
self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
self.wait()
def on_handler_waiting(self):

View File

@ -15,6 +15,21 @@ def import_object(name):
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
return getattr(obj, parts[-1])
# Fake byte literal support: In python 2.6+, you can say b"foo" to get
# a byte literal (str in 2.x, bytes in 3.x). There's no way to do this
# in a way that supports 2.5, though, so we need a function wrapper
# to convert our string literals. b() should only be applied to literal
# ascii strings. Once we drop support for 2.5, we can remove this function
# and just use byte literals.
#
# NOTE: the test must not mention the name `unicode`, which does not
# exist on Python 3; comparing str against bytes distinguishes the two
# major versions without raising NameError on either.
if str is not bytes:
    # Python 3: text (str) and binary (bytes) are distinct types, so an
    # ascii literal must be explicitly encoded to get bytes.
    def b(s):
        """Return the byte-string form of an ascii-only literal string."""
        return s.encode('ascii')
    bytes_type = bytes
else:
    # Python 2: str already is the byte type, so b() is a no-op.
    def b(s):
        """Return the byte-string form of an ascii-only literal string."""
        return s
    bytes_type = str
def doctests():
    """Collect this module's doctest examples into a unittest suite."""
    # Imported lazily so importing the module itself stays cheap.
    import doctest
    suite = doctest.DocTestSuite()
    return suite

View File

@ -82,6 +82,7 @@ from tornado import locale
from tornado import stack_context
from tornado import template
from tornado.escape import utf8
from tornado.util import b, bytes_type
class RequestHandler(object):
"""Subclass this class and define get() or post() to make a handler.
@ -211,7 +212,7 @@ class RequestHandler(object):
# If \n is allowed into the header, it is possible to inject
# additional headers or split the request. Also cap length to
# prevent obviously erroneous values.
safe_value = re.sub(r"[\x00-\x1f]", " ", value)[:4000]
safe_value = re.sub(b(r"[\x00-\x1f]"), b(" "), value)[:4000]
if safe_value != value:
raise ValueError("Unsafe header value %r", value)
self._headers[name] = value
@ -244,8 +245,8 @@ class RequestHandler(object):
"""
values = self.request.arguments.get(name, [])
# Get rid of any weird control chars
values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", x) for x in values]
values = [_unicode(x) for x in values]
values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", _unicode(x))
for x in values]
if strip:
values = [x.strip() for x in values]
return values
@ -332,10 +333,10 @@ class RequestHandler(object):
method for non-cookie uses. To decode a value not stored
as a cookie use the optional value argument to get_secure_cookie.
"""
timestamp = str(int(time.time()))
value = base64.b64encode(value)
timestamp = utf8(str(int(time.time())))
value = base64.b64encode(utf8(value))
signature = self._cookie_signature(name, value, timestamp)
value = "|".join([value, timestamp, signature])
value = b("|").join([value, timestamp, signature])
return value
def get_secure_cookie(self, name, include_name=True, value=None):
@ -350,7 +351,7 @@ class RequestHandler(object):
"""
if value is None: value = self.get_cookie(name)
if not value: return None
parts = value.split("|")
parts = value.split(b("|"))
if len(parts) != 3: return None
if include_name:
signature = self._cookie_signature(name, parts[0], parts[1])
@ -371,7 +372,7 @@ class RequestHandler(object):
# here instead of modifying _cookie_signature.
logging.warning("Cookie timestamp in future; possible tampering %r", value)
return None
if parts[1].startswith("0"):
if parts[1].startswith(b("0")):
logging.warning("Tampered cookie %r", value)
try:
return base64.b64decode(parts[0])
@ -380,10 +381,10 @@ class RequestHandler(object):
def _cookie_signature(self, *parts):
self.require_setting("cookie_secret", "secure cookies")
hash = hmac.new(self.application.settings["cookie_secret"],
hash = hmac.new(utf8(self.application.settings["cookie_secret"]),
digestmod=hashlib.sha1)
for part in parts: hash.update(part)
return hash.hexdigest()
for part in parts: hash.update(utf8(part))
return utf8(hash.hexdigest())
def redirect(self, url, permanent=False):
"""Sends a redirect to the given (optionally relative) URL."""
@ -391,8 +392,9 @@ class RequestHandler(object):
raise Exception("Cannot redirect after headers have been written")
self.set_status(301 if permanent else 302)
# Remove whitespace
url = re.sub(r"[\x00-\x20]+", "", utf8(url))
self.set_header("Location", urlparse.urljoin(self.request.uri, url))
url = re.sub(b(r"[\x00-\x20]+"), "", utf8(url))
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
url))
self.finish()
def write(self, chunk):
@ -534,7 +536,7 @@ class RequestHandler(object):
if self.application._wsgi:
raise Exception("WSGI applications do not support flush()")
chunk = "".join(self._write_buffer)
chunk = b("").join(self._write_buffer)
self._write_buffer = []
if not self._headers_written:
self._headers_written = True
@ -545,7 +547,7 @@ class RequestHandler(object):
else:
for transform in self._transforms:
chunk = transform.transform_chunk(chunk, include_footers)
headers = ""
headers = b("")
# Ignore the chunk and only write the headers for HEAD requests
if self.request.method == "HEAD":
@ -864,13 +866,14 @@ class RequestHandler(object):
self.finish()
def _generate_headers(self):
lines = [self.request.version + " " + str(self._status_code) + " " +
httplib.responses[self._status_code]]
lines.extend(["%s: %s" % (n, v) for n, v in self._headers.iteritems()])
lines = [utf8(self.request.version + " " +
str(self._status_code) +
" " + httplib.responses[self._status_code])]
lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in self._headers.iteritems()])
for cookie_dict in getattr(self, "_new_cookies", []):
for cookie in cookie_dict.values():
lines.append("Set-Cookie: " + cookie.OutputString(None))
return "\r\n".join(lines) + "\r\n\r\n"
lines.append(b("Set-Cookie: ") + cookie.OutputString(None))
return b("\r\n").join(lines) + b("\r\n\r\n")
def _log(self):
"""Logs the current request.
@ -882,8 +885,8 @@ class RequestHandler(object):
self.application.log_request(self)
def _request_summary(self):
return self.request.method + " " + self.request.uri + " (" + \
self.request.remote_ip + ")"
return utf8(self.request.method) + b(" ") + utf8(self.request.uri) + \
b(" (") + utf8(self.request.remote_ip) + b(")")
def _handle_request_exception(self, e):
if isinstance(e, HTTPError):
@ -1472,9 +1475,9 @@ class ChunkedTransferEncoding(OutputTransform):
# Don't write out empty chunks because that means END-OF-STREAM
# with chunked encoding
if block:
block = ("%x" % len(block)) + "\r\n" + block + "\r\n"
block = utf8("%x" % len(block)) + b("\r\n") + block + b("\r\n")
if finishing:
block += "0\r\n\r\n"
block += b("0\r\n\r\n")
return block
@ -1607,7 +1610,7 @@ url = URLSpec
def _unicode(s):
if isinstance(s, str):
if isinstance(s, bytes_type):
try:
return s.decode("utf-8")
except UnicodeDecodeError: