Merge branch 'python3'
This commit is contained in:
commit
3b3534c243
|
@ -74,7 +74,7 @@ def json_encode(value):
|
|||
|
||||
def json_decode(value):
|
||||
"""Returns Python objects for the given JSON string."""
|
||||
return _json_decode(value)
|
||||
return _json_decode(_unicode(value))
|
||||
|
||||
|
||||
def squeeze(value):
|
||||
|
@ -92,13 +92,12 @@ def url_unescape(value):
|
|||
return _unicode(urllib.unquote_plus(value))
|
||||
|
||||
|
||||
_UTF8_TYPES = (bytes, type(None))
|
||||
def utf8(value):
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, unicode):
|
||||
return value.encode("utf-8")
|
||||
assert isinstance(value, bytes)
|
||||
return value
|
||||
if isinstance(value, _UTF8_TYPES):
|
||||
return value
|
||||
assert isinstance(value, unicode)
|
||||
return value.encode("utf-8")
|
||||
|
||||
|
||||
# I originally used the regex from
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
"""A non-blocking, single-threaded HTTP server."""
|
||||
|
||||
import cgi
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
|
@ -28,6 +27,12 @@ from tornado import httputil
|
|||
from tornado import ioloop
|
||||
from tornado import iostream
|
||||
from tornado import stack_context
|
||||
from tornado.util import b, bytes_type
|
||||
|
||||
try:
|
||||
from urlparse import parse_qs # Python 2.6+
|
||||
except ImportError:
|
||||
from cgi import parse_qs
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
|
@ -289,7 +294,7 @@ class HTTPConnection(object):
|
|||
# Save stack context here, outside of any request. This keeps
|
||||
# contexts from one request from leaking into the next.
|
||||
self._header_callback = stack_context.wrap(self._on_headers)
|
||||
self.stream.read_until("\r\n\r\n", self._header_callback)
|
||||
self.stream.read_until(b("\r\n\r\n"), self._header_callback)
|
||||
|
||||
def write(self, chunk):
|
||||
assert self._request, "Request closed"
|
||||
|
@ -323,10 +328,11 @@ class HTTPConnection(object):
|
|||
if disconnect:
|
||||
self.stream.close()
|
||||
return
|
||||
self.stream.read_until("\r\n\r\n", self._header_callback)
|
||||
self.stream.read_until(b("\r\n\r\n"), self._header_callback)
|
||||
|
||||
def _on_headers(self, data):
|
||||
try:
|
||||
data = data.decode('latin1')
|
||||
eol = data.find("\r\n")
|
||||
start_line = data[:eol]
|
||||
try:
|
||||
|
@ -362,8 +368,9 @@ class HTTPConnection(object):
|
|||
content_type = self._request.headers.get("Content-Type", "")
|
||||
if self._request.method in ("POST", "PUT"):
|
||||
if content_type.startswith("application/x-www-form-urlencoded"):
|
||||
arguments = cgi.parse_qs(self._request.body)
|
||||
arguments = parse_qs(self._request.body)
|
||||
for name, values in arguments.iteritems():
|
||||
name = name.decode('utf-8')
|
||||
values = [v for v in values if v]
|
||||
if values:
|
||||
self._request.arguments.setdefault(name, []).extend(
|
||||
|
@ -412,7 +419,7 @@ class HTTPConnection(object):
|
|||
if not name_values.get("name"):
|
||||
logging.warning("multipart/form-data value missing name")
|
||||
continue
|
||||
name = name_values["name"]
|
||||
name = name_values["name"].decode("utf-8")
|
||||
if name_values.get("filename"):
|
||||
ctype = headers.get("Content-Type", "application/unknown")
|
||||
self._request.files.setdefault(name, []).append(dict(
|
||||
|
@ -475,7 +482,7 @@ class HTTPRequest(object):
|
|||
scheme, netloc, path, query, fragment = urlparse.urlsplit(uri)
|
||||
self.path = path
|
||||
self.query = query
|
||||
arguments = cgi.parse_qs(query)
|
||||
arguments = parse_qs(query)
|
||||
self.arguments = {}
|
||||
for name, values in arguments.iteritems():
|
||||
values = [v for v in values if v]
|
||||
|
@ -487,7 +494,7 @@ class HTTPRequest(object):
|
|||
|
||||
def write(self, chunk):
|
||||
"""Writes the given chunk to the response stream."""
|
||||
assert isinstance(chunk, str)
|
||||
assert isinstance(chunk, bytes_type)
|
||||
self.connection.write(chunk)
|
||||
|
||||
def finish(self):
|
||||
|
|
|
@ -27,6 +27,7 @@ import sys
|
|||
|
||||
from tornado import ioloop
|
||||
from tornado import stack_context
|
||||
from tornado.util import b, bytes_type
|
||||
|
||||
try:
|
||||
import ssl # Python 2.6+
|
||||
|
@ -142,7 +143,7 @@ class IOStream(object):
|
|||
"""Call callback when we read the given number of bytes."""
|
||||
assert not self._read_callback, "Already reading"
|
||||
if num_bytes == 0:
|
||||
callback("")
|
||||
callback(b(""))
|
||||
return
|
||||
self._read_bytes = num_bytes
|
||||
self._read_callback = stack_context.wrap(callback)
|
||||
|
@ -162,6 +163,7 @@ class IOStream(object):
|
|||
previously buffered write data and an old write callback, that
|
||||
callback is simply overwritten with this new callback.
|
||||
"""
|
||||
assert isinstance(data, bytes_type)
|
||||
self._check_closed()
|
||||
self._write_buffer.append(data)
|
||||
self._add_io_state(self.io_loop.WRITE)
|
||||
|
@ -527,7 +529,13 @@ def _merge_prefix(deque, size):
|
|||
chunk = chunk[:remaining]
|
||||
prefix.append(chunk)
|
||||
remaining -= len(chunk)
|
||||
deque.appendleft(''.join(prefix))
|
||||
# This data structure normally just contains byte strings, but
|
||||
# the unittest gets messy if it doesn't use the default str() type,
|
||||
# so do the merge based on the type of data that's actually present.
|
||||
if prefix:
|
||||
deque.appendleft(type(prefix[0])().join(prefix))
|
||||
if not deque:
|
||||
deque.appendleft(b(""))
|
||||
|
||||
def doctests():
|
||||
import doctest
|
||||
|
|
|
@ -55,6 +55,8 @@ import re
|
|||
import sys
|
||||
import time
|
||||
|
||||
from tornado.escape import _unicode
|
||||
|
||||
# For pretty log messages, if available
|
||||
try:
|
||||
import curses
|
||||
|
@ -300,7 +302,7 @@ class _Option(object):
|
|||
return value.lower() not in ("false", "0", "f")
|
||||
|
||||
def _parse_string(self, value):
|
||||
return value.decode("utf-8")
|
||||
return _unicode(value)
|
||||
|
||||
|
||||
class Error(Exception):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import with_statement
|
||||
|
||||
from cStringIO import StringIO
|
||||
from tornado.escape import utf8, _unicode
|
||||
from tornado.httpclient import HTTPRequest, HTTPResponse, HTTPError, AsyncHTTPClient
|
||||
from tornado.httputil import HTTPHeaders
|
||||
from tornado.ioloop import IOLoop
|
||||
from tornado.iostream import IOStream, SSLIOStream
|
||||
from tornado import stack_context
|
||||
from tornado.util import b
|
||||
|
||||
import base64
|
||||
import collections
|
||||
|
@ -22,6 +23,11 @@ import time
|
|||
import urlparse
|
||||
import zlib
|
||||
|
||||
try:
|
||||
from io import BytesIO # python 3
|
||||
except ImportError:
|
||||
from cStringIO import StringIO as BytesIO # python 2
|
||||
|
||||
try:
|
||||
import ssl # python 2.6+
|
||||
except ImportError:
|
||||
|
@ -127,7 +133,7 @@ class _HTTPConnection(object):
|
|||
# Timeout handle returned by IOLoop.add_timeout
|
||||
self._timeout = None
|
||||
with stack_context.StackContext(self.cleanup):
|
||||
parsed = urlparse.urlsplit(self.request.url)
|
||||
parsed = urlparse.urlsplit(_unicode(self.request.url))
|
||||
host = parsed.hostname
|
||||
if parsed.port is None:
|
||||
port = 443 if parsed.scheme == "https" else 80
|
||||
|
@ -196,16 +202,16 @@ class _HTTPConnection(object):
|
|||
username = self.request.auth_username
|
||||
password = self.request.auth_password
|
||||
if username is not None:
|
||||
auth = "%s:%s" % (username, password)
|
||||
self.request.headers["Authorization"] = ("Basic %s" %
|
||||
auth = utf8(username) + b(":") + utf8(password)
|
||||
self.request.headers["Authorization"] = (b("Basic ") +
|
||||
base64.b64encode(auth))
|
||||
if self.request.user_agent:
|
||||
self.request.headers["User-Agent"] = self.request.user_agent
|
||||
has_body = self.request.method in ("POST", "PUT")
|
||||
if has_body:
|
||||
assert self.request.body is not None
|
||||
self.request.headers["Content-Length"] = len(
|
||||
self.request.body)
|
||||
self.request.headers["Content-Length"] = str(len(
|
||||
self.request.body))
|
||||
else:
|
||||
assert self.request.body is None
|
||||
if (self.request.method == "POST" and
|
||||
|
@ -215,17 +221,17 @@ class _HTTPConnection(object):
|
|||
self.request.headers["Accept-Encoding"] = "gzip"
|
||||
req_path = ((parsed.path or '/') +
|
||||
(('?' + parsed.query) if parsed.query else ''))
|
||||
request_lines = ["%s %s HTTP/1.1" % (self.request.method,
|
||||
req_path)]
|
||||
request_lines = [utf8("%s %s HTTP/1.1" % (self.request.method,
|
||||
req_path))]
|
||||
for k, v in self.request.headers.get_all():
|
||||
line = "%s: %s" % (k, v)
|
||||
if '\n' in line:
|
||||
line = utf8(k) + b(": ") + utf8(v)
|
||||
if b('\n') in line:
|
||||
raise ValueError('Newline in header: ' + repr(line))
|
||||
request_lines.append(line)
|
||||
self.stream.write("\r\n".join(request_lines) + "\r\n\r\n")
|
||||
self.stream.write(b("\r\n").join(request_lines) + b("\r\n\r\n"))
|
||||
if has_body:
|
||||
self.stream.write(self.request.body)
|
||||
self.stream.read_until("\r\n\r\n", self._on_headers)
|
||||
self.stream.write(utf8(self.request.body))
|
||||
self.stream.read_until(b("\r\n\r\n"), self._on_headers)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def cleanup(self):
|
||||
|
@ -246,6 +252,7 @@ class _HTTPConnection(object):
|
|||
error=HTTPError(599, "Connection closed")))
|
||||
|
||||
def _on_headers(self, data):
|
||||
data = data.decode("latin1")
|
||||
first_line, _, header_data = data.partition("\r\n")
|
||||
match = re.match("HTTP/1.[01] ([0-9]+) .*", first_line)
|
||||
assert match
|
||||
|
@ -261,7 +268,7 @@ class _HTTPConnection(object):
|
|||
self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS)
|
||||
if self.headers.get("Transfer-Encoding") == "chunked":
|
||||
self.chunks = []
|
||||
self.stream.read_until("\r\n", self._on_chunk_length)
|
||||
self.stream.read_until(b("\r\n"), self._on_chunk_length)
|
||||
elif "Content-Length" in self.headers:
|
||||
self.stream.read_bytes(int(self.headers["Content-Length"]),
|
||||
self._on_body)
|
||||
|
@ -280,9 +287,9 @@ class _HTTPConnection(object):
|
|||
# if chunks is not None, we already called streaming_callback
|
||||
# in _on_chunk_data
|
||||
self.request.streaming_callback(data)
|
||||
buffer = StringIO()
|
||||
buffer = BytesIO()
|
||||
else:
|
||||
buffer = StringIO(data) # TODO: don't require one big string?
|
||||
buffer = BytesIO(data) # TODO: don't require one big string?
|
||||
original_request = getattr(self.request, "original_request",
|
||||
self.request)
|
||||
if (self.request.follow_redirects and
|
||||
|
@ -290,7 +297,7 @@ class _HTTPConnection(object):
|
|||
self.code in (301, 302)):
|
||||
new_request = copy.copy(self.request)
|
||||
new_request.url = urlparse.urljoin(self.request.url,
|
||||
self.headers["Location"])
|
||||
utf8(self.headers["Location"]))
|
||||
new_request.max_redirects -= 1
|
||||
del new_request.headers["Host"]
|
||||
new_request.original_request = original_request
|
||||
|
@ -311,13 +318,13 @@ class _HTTPConnection(object):
|
|||
# all the data has been decompressed, so we don't need to
|
||||
# decompress again in _on_body
|
||||
self._decompressor = None
|
||||
self._on_body(''.join(self.chunks))
|
||||
self._on_body(b('').join(self.chunks))
|
||||
else:
|
||||
self.stream.read_bytes(length + 2, # chunk ends with \r\n
|
||||
self._on_chunk_data)
|
||||
|
||||
def _on_chunk_data(self, data):
|
||||
assert data[-2:] == "\r\n"
|
||||
assert data[-2:] == b("\r\n")
|
||||
chunk = data[:-2]
|
||||
if self._decompressor:
|
||||
chunk = self._decompressor.decompress(chunk)
|
||||
|
@ -325,7 +332,7 @@ class _HTTPConnection(object):
|
|||
self.request.streaming_callback(chunk)
|
||||
else:
|
||||
self.chunks.append(chunk)
|
||||
self.stream.read_until("\r\n", self._on_chunk_length)
|
||||
self.stream.read_until(b("\r\n"), self._on_chunk_length)
|
||||
|
||||
|
||||
# match_hostname was added to the standard library ssl module in python 3.2.
|
||||
|
|
|
@ -48,14 +48,15 @@ Example usage:
|
|||
|
||||
from __future__ import with_statement
|
||||
|
||||
from types import NoneType
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
class _State(threading.local):
|
||||
def __init__(self):
|
||||
self.contexts = ()
|
||||
|
@ -172,7 +173,7 @@ def wrap(fn):
|
|||
new_contexts = [cls(arg)
|
||||
for (cls, arg) in contexts[len(_state.contexts):]]
|
||||
if len(new_contexts) > 1:
|
||||
with contextlib.nested(*new_contexts):
|
||||
with _nested(*new_contexts):
|
||||
callback(*args, **kwargs)
|
||||
elif new_contexts:
|
||||
with new_contexts[0]:
|
||||
|
@ -183,3 +184,38 @@ def wrap(fn):
|
|||
return fn
|
||||
return _StackContextWrapper(wrapped, fn, _state.contexts)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _nested(*managers):
|
||||
"""Support multiple context managers in a single with-statement.
|
||||
|
||||
Copied from the python 2.6 standard library. It's no longer present
|
||||
in python 3 because the with statement natively supports multiple
|
||||
context managers, but that doesn't help if the list of context
|
||||
managers is not known until runtime.
|
||||
"""
|
||||
exits = []
|
||||
vars = []
|
||||
exc = (None, None, None)
|
||||
try:
|
||||
for mgr in managers:
|
||||
exit = mgr.__exit__
|
||||
enter = mgr.__enter__
|
||||
vars.append(enter())
|
||||
exits.append(exit)
|
||||
yield vars
|
||||
except:
|
||||
exc = sys.exc_info()
|
||||
finally:
|
||||
while exits:
|
||||
exit = exits.pop()
|
||||
try:
|
||||
if exit(*exc):
|
||||
exc = (None, None, None)
|
||||
except:
|
||||
exc = sys.exc_info()
|
||||
if exc != (None, None, None):
|
||||
# Don't rely on sys.exc_info() still containing
|
||||
# the right information. Another exception may
|
||||
# have been raised and caught by an exit method
|
||||
raise exc[0], exc[1], exc[2]
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from tornado.iostream import IOStream
|
||||
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
|
||||
from tornado.util import b
|
||||
from tornado.web import RequestHandler, Application
|
||||
import socket
|
||||
|
||||
|
@ -15,22 +16,22 @@ class TestIOStream(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("localhost", self.get_http_port()))
|
||||
self.stream = IOStream(s, io_loop=self.io_loop)
|
||||
self.stream.write("GET / HTTP/1.0\r\n\r\n")
|
||||
self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
|
||||
|
||||
# normal read
|
||||
self.stream.read_bytes(9, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "HTTP/1.0 ")
|
||||
self.assertEqual(data, b("HTTP/1.0 "))
|
||||
|
||||
# zero bytes
|
||||
self.stream.read_bytes(0, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "")
|
||||
self.assertEqual(data, b(""))
|
||||
|
||||
# another normal read
|
||||
self.stream.read_bytes(3, self.stop)
|
||||
data = self.wait()
|
||||
self.assertEqual(data, "200")
|
||||
self.assertEqual(data, b("200"))
|
||||
|
||||
def test_connection_refused(self):
|
||||
# When a connection is refused, the connect callback should not
|
||||
|
|
|
@ -12,6 +12,7 @@ from contextlib import closing
|
|||
from tornado.ioloop import IOLoop
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS
|
||||
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
|
||||
from tornado.util import b
|
||||
from tornado.web import Application, RequestHandler, asynchronous, url
|
||||
|
||||
class HelloWorldHandler(RequestHandler):
|
||||
|
@ -84,10 +85,10 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
response = self.fetch("/hello")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.headers["Content-Type"], "text/plain")
|
||||
self.assertEqual(response.body, "Hello world!")
|
||||
self.assertEqual(response.body, b("Hello world!"))
|
||||
|
||||
response = self.fetch("/hello?name=Ben")
|
||||
self.assertEqual(response.body, "Hello Ben!")
|
||||
self.assertEqual(response.body, b("Hello Ben!"))
|
||||
|
||||
def test_streaming_callback(self):
|
||||
# streaming_callback is also tested in test_chunked
|
||||
|
@ -95,29 +96,29 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
response = self.fetch("/hello",
|
||||
streaming_callback=chunks.append)
|
||||
# with streaming_callback, data goes to the callback and not response.body
|
||||
self.assertEqual(chunks, ["Hello world!"])
|
||||
self.assertEqual(chunks, [b("Hello world!")])
|
||||
self.assertFalse(response.body)
|
||||
|
||||
def test_post(self):
|
||||
response = self.fetch("/post", method="POST",
|
||||
body="arg1=foo&arg2=bar")
|
||||
self.assertEqual(response.code, 200)
|
||||
self.assertEqual(response.body, "Post arg1: foo, arg2: bar")
|
||||
self.assertEqual(response.body, b("Post arg1: foo, arg2: bar"))
|
||||
|
||||
def test_chunked(self):
|
||||
response = self.fetch("/chunk")
|
||||
self.assertEqual(response.body, "asdfqwer")
|
||||
self.assertEqual(response.body, b("asdfqwer"))
|
||||
|
||||
chunks = []
|
||||
response = self.fetch("/chunk",
|
||||
streaming_callback=chunks.append)
|
||||
self.assertEqual(chunks, ["asdf", "qwer"])
|
||||
self.assertEqual(chunks, [b("asdf"), b("qwer")])
|
||||
self.assertFalse(response.body)
|
||||
|
||||
def test_basic_auth(self):
|
||||
self.assertEqual(self.fetch("/auth", auth_username="Aladdin",
|
||||
auth_password="open sesame").body,
|
||||
"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
|
||||
b("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="))
|
||||
|
||||
def test_gzip(self):
|
||||
# All the tests in this file should be using gzip, but this test
|
||||
|
@ -127,11 +128,11 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
response = self.fetch("/chunk", use_gzip=False,
|
||||
headers={"Accept-Encoding": "gzip"})
|
||||
self.assertEqual(response.headers["Content-Encoding"], "gzip")
|
||||
self.assertNotEqual(response.body, "asdfqwer")
|
||||
self.assertNotEqual(response.body, b("asdfqwer"))
|
||||
# Our test data gets bigger when gzipped. Oops. :)
|
||||
self.assertEqual(len(response.body), 34)
|
||||
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
|
||||
self.assertEqual(f.read(), "asdfqwer")
|
||||
self.assertEqual(f.read(), b("asdfqwer"))
|
||||
|
||||
def test_connect_timeout(self):
|
||||
# create a socket and bind it to a port, but don't
|
||||
|
@ -203,25 +204,25 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
|
||||
response = self.fetch("/countdown/2")
|
||||
self.assertEqual(200, response.code)
|
||||
self.assertTrue(response.effective_url.endswith("/countdown/0"))
|
||||
self.assertEqual("Zero", response.body)
|
||||
self.assertTrue(response.effective_url.endswith(b("/countdown/0")))
|
||||
self.assertEqual(b("Zero"), response.body)
|
||||
|
||||
def test_max_redirects(self):
|
||||
response = self.fetch("/countdown/5", max_redirects=3)
|
||||
self.assertEqual(302, response.code)
|
||||
# We requested 5, followed three redirects for 4, 3, 2, then the last
|
||||
# unfollowed redirect is to 1.
|
||||
self.assertTrue(response.request.url.endswith("/countdown/5"))
|
||||
self.assertTrue(response.effective_url.endswith("/countdown/2"))
|
||||
self.assertTrue(response.request.url.endswith(b("/countdown/5")))
|
||||
self.assertTrue(response.effective_url.endswith(b("/countdown/2")))
|
||||
self.assertTrue(response.headers["Location"].endswith("/countdown/1"))
|
||||
|
||||
def test_default_certificates_exist(self):
|
||||
open(_DEFAULT_CA_CERTS)
|
||||
open(_DEFAULT_CA_CERTS).close()
|
||||
|
||||
def test_credentials_in_url(self):
|
||||
url = self.get_url("/auth").replace("http://", "http://me:secret@")
|
||||
self.http_client.fetch(url, self.stop)
|
||||
response = self.wait()
|
||||
self.assertEqual("Basic " + base64.b64encode("me:secret"),
|
||||
self.assertEqual(b("Basic ") + base64.b64encode(b("me:secret")),
|
||||
response.body)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import with_statement
|
|||
|
||||
from tornado.stack_context import StackContext, wrap
|
||||
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase
|
||||
from tornado.util import b
|
||||
from tornado.web import asynchronous, Application, RequestHandler
|
||||
import contextlib
|
||||
import functools
|
||||
|
@ -45,8 +46,8 @@ class HTTPStackContextTest(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
def test_stack_context(self):
|
||||
self.http_client.fetch(self.get_url('/'), self.handle_response)
|
||||
self.wait()
|
||||
self.assertEquals(self.response.code, 500)
|
||||
self.assertTrue('got expected exception' in self.response.body)
|
||||
self.assertEqual(self.response.code, 500)
|
||||
self.assertTrue(b('got expected exception') in self.response.body)
|
||||
|
||||
def handle_response(self, response):
|
||||
self.response = response
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from tornado.escape import json_decode
|
||||
from tornado.escape import json_decode, utf8
|
||||
from tornado.iostream import IOStream
|
||||
from tornado.testing import LogTrapTestCase, AsyncHTTPTestCase
|
||||
from tornado.util import b
|
||||
from tornado.web import RequestHandler, _O, authenticated, Application, asynchronous
|
||||
|
||||
import binascii
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
|
@ -24,15 +26,15 @@ class CookieTestRequestHandler(RequestHandler):
|
|||
class SecureCookieTest(LogTrapTestCase):
|
||||
def test_round_trip(self):
|
||||
handler = CookieTestRequestHandler()
|
||||
handler.set_secure_cookie('foo', 'bar')
|
||||
self.assertEquals(handler.get_secure_cookie('foo'), 'bar')
|
||||
handler.set_secure_cookie('foo', b('bar'))
|
||||
self.assertEqual(handler.get_secure_cookie('foo'), b('bar'))
|
||||
|
||||
def test_cookie_tampering_future_timestamp(self):
|
||||
handler = CookieTestRequestHandler()
|
||||
# this string base64-encodes to '12345678'
|
||||
handler.set_secure_cookie('foo', '\xd7m\xf8\xe7\xae\xfc')
|
||||
handler.set_secure_cookie('foo', binascii.a2b_hex(b('d76df8e7aefc')))
|
||||
cookie = handler._cookies['foo']
|
||||
match = re.match(r'12345678\|([0-9]+)\|([0-9a-f]+)', cookie)
|
||||
match = re.match(b(r'12345678\|([0-9]+)\|([0-9a-f]+)'), cookie)
|
||||
assert match
|
||||
timestamp = match.group(1)
|
||||
sig = match.group(2)
|
||||
|
@ -41,10 +43,11 @@ class SecureCookieTest(LogTrapTestCase):
|
|||
# shifting digits from payload to timestamp doesn't alter signature
|
||||
# (this is not desirable behavior, just confirming that that's how it
|
||||
# works)
|
||||
self.assertEqual(handler._cookie_signature('foo', '1234',
|
||||
'5678' + timestamp), sig)
|
||||
self.assertEqual(
|
||||
handler._cookie_signature('foo', '1234', b('5678') + timestamp),
|
||||
sig)
|
||||
# tamper with the cookie
|
||||
handler._cookies['foo'] = '1234|5678%s|%s' % (timestamp, sig)
|
||||
handler._cookies['foo'] = utf8('1234|5678%s|%s' % (timestamp, sig))
|
||||
# it gets rejected
|
||||
assert handler.get_secure_cookie('foo') is None
|
||||
|
||||
|
@ -104,7 +107,7 @@ class ConnectionCloseTest(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
|
||||
s.connect(("localhost", self.get_http_port()))
|
||||
self.stream = IOStream(s, io_loop=self.io_loop)
|
||||
self.stream.write("GET / HTTP/1.0\r\n\r\n")
|
||||
self.stream.write(b("GET / HTTP/1.0\r\n\r\n"))
|
||||
self.wait()
|
||||
|
||||
def on_handler_waiting(self):
|
||||
|
|
|
@ -15,6 +15,21 @@ def import_object(name):
|
|||
obj = __import__('.'.join(parts[:-1]), None, None, [parts[-1]], 0)
|
||||
return getattr(obj, parts[-1])
|
||||
|
||||
# Fake byte literal support: In python 2.6+, you can say b"foo" to get
|
||||
# a byte literal (str in 2.x, bytes in 3.x). There's no way to do this
|
||||
# in a way that supports 2.5, though, so we need a function wrapper
|
||||
# to convert our string literals. b() should only be applied to literal
|
||||
# ascii strings. Once we drop support for 2.5, we can remove this function
|
||||
# and just use byte literals.
|
||||
if str is unicode:
|
||||
def b(s):
|
||||
return s.encode('ascii')
|
||||
bytes_type = bytes
|
||||
else:
|
||||
def b(s):
|
||||
return s
|
||||
bytes_type = str
|
||||
|
||||
def doctests():
|
||||
import doctest
|
||||
return doctest.DocTestSuite()
|
||||
|
|
|
@ -54,7 +54,6 @@ from __future__ import with_statement
|
|||
import Cookie
|
||||
import base64
|
||||
import binascii
|
||||
import cStringIO
|
||||
import calendar
|
||||
import contextlib
|
||||
import datetime
|
||||
|
@ -81,6 +80,13 @@ from tornado import escape
|
|||
from tornado import locale
|
||||
from tornado import stack_context
|
||||
from tornado import template
|
||||
from tornado.escape import utf8
|
||||
from tornado.util import b, bytes_type
|
||||
|
||||
try:
|
||||
from io import BytesIO # python 3
|
||||
except ImportError:
|
||||
from cStringIO import StringIO as BytesIO # python 2
|
||||
|
||||
class RequestHandler(object):
|
||||
"""Subclass this class and define get() or post() to make a handler.
|
||||
|
@ -206,11 +212,11 @@ class RequestHandler(object):
|
|||
elif isinstance(value, int) or isinstance(value, long):
|
||||
value = str(value)
|
||||
else:
|
||||
value = _utf8(value)
|
||||
value = utf8(value)
|
||||
# If \n is allowed into the header, it is possible to inject
|
||||
# additional headers or split the request. Also cap length to
|
||||
# prevent obviously erroneous values.
|
||||
safe_value = re.sub(r"[\x00-\x1f]", " ", value)[:4000]
|
||||
safe_value = re.sub(b(r"[\x00-\x1f]"), b(" "), value)[:4000]
|
||||
if safe_value != value:
|
||||
raise ValueError("Unsafe header value %r", value)
|
||||
self._headers[name] = value
|
||||
|
@ -243,8 +249,8 @@ class RequestHandler(object):
|
|||
"""
|
||||
values = self.request.arguments.get(name, [])
|
||||
# Get rid of any weird control chars
|
||||
values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", x) for x in values]
|
||||
values = [_unicode(x) for x in values]
|
||||
values = [re.sub(r"[\x00-\x08\x0e-\x1f]", " ", _unicode(x))
|
||||
for x in values]
|
||||
if strip:
|
||||
values = [x.strip() for x in values]
|
||||
return values
|
||||
|
@ -277,8 +283,8 @@ class RequestHandler(object):
|
|||
See http://docs.python.org/library/cookie.html#morsel-objects
|
||||
for available attributes.
|
||||
"""
|
||||
name = _utf8(name)
|
||||
value = _utf8(value)
|
||||
name = utf8(name)
|
||||
value = utf8(value)
|
||||
if re.search(r"[\x00-\x20]", name + value):
|
||||
# Don't let us accidentally inject bad stuff
|
||||
raise ValueError("Invalid cookie %r: %r" % (name, value))
|
||||
|
@ -331,10 +337,10 @@ class RequestHandler(object):
|
|||
method for non-cookie uses. To decode a value not stored
|
||||
as a cookie use the optional value argument to get_secure_cookie.
|
||||
"""
|
||||
timestamp = str(int(time.time()))
|
||||
value = base64.b64encode(value)
|
||||
timestamp = utf8(str(int(time.time())))
|
||||
value = base64.b64encode(utf8(value))
|
||||
signature = self._cookie_signature(name, value, timestamp)
|
||||
value = "|".join([value, timestamp, signature])
|
||||
value = b("|").join([value, timestamp, signature])
|
||||
return value
|
||||
|
||||
def get_secure_cookie(self, name, include_name=True, value=None):
|
||||
|
@ -349,7 +355,7 @@ class RequestHandler(object):
|
|||
"""
|
||||
if value is None: value = self.get_cookie(name)
|
||||
if not value: return None
|
||||
parts = value.split("|")
|
||||
parts = utf8(value).split(b("|"))
|
||||
if len(parts) != 3: return None
|
||||
if include_name:
|
||||
signature = self._cookie_signature(name, parts[0], parts[1])
|
||||
|
@ -370,7 +376,7 @@ class RequestHandler(object):
|
|||
# here instead of modifying _cookie_signature.
|
||||
logging.warning("Cookie timestamp in future; possible tampering %r", value)
|
||||
return None
|
||||
if parts[1].startswith("0"):
|
||||
if parts[1].startswith(b("0")):
|
||||
logging.warning("Tampered cookie %r", value)
|
||||
try:
|
||||
return base64.b64decode(parts[0])
|
||||
|
@ -379,10 +385,10 @@ class RequestHandler(object):
|
|||
|
||||
def _cookie_signature(self, *parts):
|
||||
self.require_setting("cookie_secret", "secure cookies")
|
||||
hash = hmac.new(self.application.settings["cookie_secret"],
|
||||
hash = hmac.new(utf8(self.application.settings["cookie_secret"]),
|
||||
digestmod=hashlib.sha1)
|
||||
for part in parts: hash.update(part)
|
||||
return hash.hexdigest()
|
||||
for part in parts: hash.update(utf8(part))
|
||||
return utf8(hash.hexdigest())
|
||||
|
||||
def redirect(self, url, permanent=False):
|
||||
"""Sends a redirect to the given (optionally relative) URL."""
|
||||
|
@ -390,8 +396,9 @@ class RequestHandler(object):
|
|||
raise Exception("Cannot redirect after headers have been written")
|
||||
self.set_status(301 if permanent else 302)
|
||||
# Remove whitespace
|
||||
url = re.sub(r"[\x00-\x20]+", "", _utf8(url))
|
||||
self.set_header("Location", urlparse.urljoin(self.request.uri, url))
|
||||
url = re.sub(b(r"[\x00-\x20]+"), "", utf8(url))
|
||||
self.set_header("Location", urlparse.urljoin(utf8(self.request.uri),
|
||||
url))
|
||||
self.finish()
|
||||
|
||||
def write(self, chunk):
|
||||
|
@ -411,7 +418,7 @@ class RequestHandler(object):
|
|||
if isinstance(chunk, dict):
|
||||
chunk = escape.json_encode(chunk)
|
||||
self.set_header("Content-Type", "text/javascript; charset=UTF-8")
|
||||
chunk = _utf8(chunk)
|
||||
chunk = utf8(chunk)
|
||||
self._write_buffer.append(chunk)
|
||||
|
||||
def render(self, template_name, **kwargs):
|
||||
|
@ -427,7 +434,7 @@ class RequestHandler(object):
|
|||
html_bodies = []
|
||||
for module in getattr(self, "_active_modules", {}).itervalues():
|
||||
embed_part = module.embedded_javascript()
|
||||
if embed_part: js_embed.append(_utf8(embed_part))
|
||||
if embed_part: js_embed.append(utf8(embed_part))
|
||||
file_part = module.javascript_files()
|
||||
if file_part:
|
||||
if isinstance(file_part, basestring):
|
||||
|
@ -435,7 +442,7 @@ class RequestHandler(object):
|
|||
else:
|
||||
js_files.extend(file_part)
|
||||
embed_part = module.embedded_css()
|
||||
if embed_part: css_embed.append(_utf8(embed_part))
|
||||
if embed_part: css_embed.append(utf8(embed_part))
|
||||
file_part = module.css_files()
|
||||
if file_part:
|
||||
if isinstance(file_part, basestring):
|
||||
|
@ -443,9 +450,9 @@ class RequestHandler(object):
|
|||
else:
|
||||
css_files.extend(file_part)
|
||||
head_part = module.html_head()
|
||||
if head_part: html_heads.append(_utf8(head_part))
|
||||
if head_part: html_heads.append(utf8(head_part))
|
||||
body_part = module.html_body()
|
||||
if body_part: html_bodies.append(_utf8(body_part))
|
||||
if body_part: html_bodies.append(utf8(body_part))
|
||||
def is_absolute(path):
|
||||
return any(path.startswith(x) for x in ["/", "http:", "https:"])
|
||||
if js_files:
|
||||
|
@ -535,7 +542,7 @@ class RequestHandler(object):
|
|||
if self.application._wsgi:
|
||||
raise Exception("WSGI applications do not support flush()")
|
||||
|
||||
chunk = "".join(self._write_buffer)
|
||||
chunk = b("").join(self._write_buffer)
|
||||
self._write_buffer = []
|
||||
if not self._headers_written:
|
||||
self._headers_written = True
|
||||
|
@ -546,7 +553,7 @@ class RequestHandler(object):
|
|||
else:
|
||||
for transform in self._transforms:
|
||||
chunk = transform.transform_chunk(chunk, include_footers)
|
||||
headers = ""
|
||||
headers = b("")
|
||||
|
||||
# Ignore the chunk and only write the headers for HEAD requests
|
||||
if self.request.method == "HEAD":
|
||||
|
@ -800,7 +807,7 @@ class RequestHandler(object):
|
|||
path)
|
||||
if abs_path not in hashes:
|
||||
try:
|
||||
f = open(abs_path)
|
||||
f = open(abs_path, "rb")
|
||||
hashes[abs_path] = hashlib.md5(f.read()).hexdigest()
|
||||
f.close()
|
||||
except:
|
||||
|
@ -873,13 +880,14 @@ class RequestHandler(object):
|
|||
self.finish()
|
||||
|
||||
def _generate_headers(self):
|
||||
lines = [self.request.version + " " + str(self._status_code) + " " +
|
||||
httplib.responses[self._status_code]]
|
||||
lines.extend(["%s: %s" % (n, v) for n, v in self._headers.iteritems()])
|
||||
lines = [utf8(self.request.version + " " +
|
||||
str(self._status_code) +
|
||||
" " + httplib.responses[self._status_code])]
|
||||
lines.extend([(utf8(n) + b(": ") + utf8(v)) for n, v in self._headers.iteritems()])
|
||||
for cookie_dict in getattr(self, "_new_cookies", []):
|
||||
for cookie in cookie_dict.values():
|
||||
lines.append("Set-Cookie: " + cookie.OutputString(None))
|
||||
return "\r\n".join(lines) + "\r\n\r\n"
|
||||
lines.append(b("Set-Cookie: ") + cookie.OutputString(None))
|
||||
return b("\r\n").join(lines) + b("\r\n\r\n")
|
||||
|
||||
def _log(self):
|
||||
"""Logs the current request.
|
||||
|
@ -891,8 +899,8 @@ class RequestHandler(object):
|
|||
self.application.log_request(self)
|
||||
|
||||
def _request_summary(self):
|
||||
return self.request.method + " " + self.request.uri + " (" + \
|
||||
self.request.remote_ip + ")"
|
||||
return self.request.method + " " + self.request.uri + \
|
||||
" (" + self.request.remote_ip + ")"
|
||||
|
||||
def _handle_request_exception(self, e):
|
||||
if isinstance(e, HTTPError):
|
||||
|
@ -1429,14 +1437,14 @@ class GZipContentEncoding(OutputTransform):
|
|||
|
||||
def transform_first_chunk(self, headers, chunk, finishing):
|
||||
if self._gzipping:
|
||||
ctype = headers.get("Content-Type", "").split(";")[0]
|
||||
ctype = _unicode(headers.get("Content-Type", "")).split(";")[0]
|
||||
self._gzipping = (ctype in self.CONTENT_TYPES) and \
|
||||
(not finishing or len(chunk) >= self.MIN_LENGTH) and \
|
||||
(finishing or "Content-Length" not in headers) and \
|
||||
("Content-Encoding" not in headers)
|
||||
if self._gzipping:
|
||||
headers["Content-Encoding"] = "gzip"
|
||||
self._gzip_value = cStringIO.StringIO()
|
||||
self._gzip_value = BytesIO()
|
||||
self._gzip_file = gzip.GzipFile(mode="w", fileobj=self._gzip_value)
|
||||
self._gzip_pos = 0
|
||||
chunk = self.transform_chunk(chunk, finishing)
|
||||
|
@ -1481,9 +1489,9 @@ class ChunkedTransferEncoding(OutputTransform):
|
|||
# Don't write out empty chunks because that means END-OF-STREAM
|
||||
# with chunked encoding
|
||||
if block:
|
||||
block = ("%x" % len(block)) + "\r\n" + block + "\r\n"
|
||||
block = utf8("%x" % len(block)) + b("\r\n") + block + b("\r\n")
|
||||
if finishing:
|
||||
block += "0\r\n\r\n"
|
||||
block += b("0\r\n\r\n")
|
||||
return block
|
||||
|
||||
|
||||
|
@ -1614,15 +1622,9 @@ class URLSpec(object):
|
|||
|
||||
url = URLSpec
|
||||
|
||||
def _utf8(s):
|
||||
if isinstance(s, unicode):
|
||||
return s.encode("utf-8")
|
||||
assert isinstance(s, str)
|
||||
return s
|
||||
|
||||
|
||||
def _unicode(s):
|
||||
if isinstance(s, str):
|
||||
if isinstance(s, bytes_type):
|
||||
try:
|
||||
return s.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
|
@ -1635,8 +1637,12 @@ def _time_independent_equals(a, b):
|
|||
if len(a) != len(b):
|
||||
return False
|
||||
result = 0
|
||||
for x, y in zip(a, b):
|
||||
result |= ord(x) ^ ord(y)
|
||||
if type(a[0]) is int: # python3 byte strings
|
||||
for x, y in zip(a,b):
|
||||
result |= x ^ y
|
||||
else: # python2
|
||||
for x, y in zip(a, b):
|
||||
result |= ord(x) ^ ord(y)
|
||||
return result == 0
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue