From 79073dbef9070d87b2e22f2233ffd3b7c47b4c80 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 29 May 2011 18:39:23 -0700 Subject: [PATCH] Type checks for httpserver.HTTPRequest fields --- tornado/httpserver.py | 3 +- tornado/test/httpserver_test.py | 51 +++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/tornado/httpserver.py b/tornado/httpserver.py index 05f43f46..bcfdc78f 100644 --- a/tornado/httpserver.py +++ b/tornado/httpserver.py @@ -398,9 +398,8 @@ class HTTPConnection(object): content_type = self._request.headers.get("Content-Type", "") if self._request.method in ("POST", "PUT"): if content_type.startswith("application/x-www-form-urlencoded"): - arguments = parse_qs(self._request.body) + arguments = parse_qs(native_str(self._request.body)) for name, values in arguments.iteritems(): - name = name.decode('utf-8') values = [v for v in values if v] if values: self._request.arguments.setdefault(name, []).extend( diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 95216128..4600f77f 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -5,7 +5,7 @@ from tornado.escape import json_decode, utf8, _unicode from tornado.iostream import IOStream from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase -from tornado.util import b +from tornado.util import b, bytes_type from tornado.web import Application, RequestHandler import logging import os @@ -140,11 +140,58 @@ class EchoHandler(RequestHandler): def get(self): self.write(self.request.arguments) +class TypeCheckHandler(RequestHandler): + def prepare(self): + self.errors = {} + fields = [ + ('method', str), + ('uri', str), + ('version', str), + ('remote_ip', str), + ('protocol', str), + ('host', str), + ('path', str), + ('query', str), + ] + for field, expected_type in fields: + self.check_type(field, getattr(self.request, field), expected_type) + + self.check_type('header_key', self.request.headers.keys()[0], str) + self.check_type('header_value', self.request.headers.values()[0], str) + + self.check_type('arg_key', self.request.arguments.keys()[0], str) + self.check_type('arg_value', self.request.arguments.values()[0][0], str) + + def post(self): + self.check_type('body', self.request.body, bytes_type) + self.write(self.errors) + + def get(self): + self.write(self.errors) + + def check_type(self, name, obj, expected_type): + actual_type = type(obj) + if expected_type != actual_type: + self.errors[name] = "expected %s, got %s" % (expected_type, + actual_type) + class HTTPServerTest(AsyncHTTPTestCase, LogTrapTestCase): def get_app(self): - return Application([("/echo", EchoHandler)]) + return Application([("/echo", EchoHandler), + ("/typecheck", TypeCheckHandler), + ]) def test_query_string_encoding(self): response = self.fetch("/echo?foo=%C3%A9") data = json_decode(response.body) self.assertEqual(data, {u"foo": [u"\u00e9"]}) + + def test_types(self): + response = self.fetch("/typecheck?foo=bar") + data = json_decode(response.body) + self.assertEqual(data, {}) + + response = self.fetch("/typecheck", method="POST", body="foo=bar") + data = json_decode(response.body) + self.assertEqual(data, {}) +