diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 45d4c747..2c657a8a 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -15,6 +15,7 @@ import logging import re import socket import urlparse +import zlib try: import ssl # python 2.6+ @@ -68,6 +69,7 @@ class _HTTPConnection(object): self.code = None self.headers = None self.chunks = None + self._decompressor = None with stack_context.StackContext(self.cleanup): parsed = urlparse.urlsplit(self.request.url) if ":" in parsed.netloc: @@ -113,6 +115,8 @@ class _HTTPConnection(object): if (self.request.method == "POST" and "Content-Type" not in self.request.headers): self.request.headers["Content-Type"] = "application/x-www-form-urlencoded" + if self.request.use_gzip: + self.request.headers["Accept-Encoding"] = "gzip" req_path = ((parsed.path or '/') + (('?' + parsed.query) if parsed.query else '')) request_lines = ["%s %s HTTP/1.1" % (self.request.method, @@ -147,6 +151,11 @@ class _HTTPConnection(object): if self.request.header_callback is not None: for k, v in self.headers.get_all(): self.request.header_callback("%s: %s\r\n" % (k, v)) + if (self.request.use_gzip and + self.headers.get("Content-Encoding") == "gzip"): + # Magic parameter makes zlib module understand gzip header + # http://stackoverflow.com/questions/1838699/how-can-i-decompress-a-gzip-stream-with-zlib + self._decompressor = zlib.decompressobj(16+zlib.MAX_WBITS) if self.headers.get("Transfer-Encoding") == "chunked": self.chunks = [] self.stream.read_until("\r\n", self._on_chunk_length) @@ -158,6 +167,8 @@ class _HTTPConnection(object): "don't know how to read %s", self.request.url) def _on_body(self, data): + if self._decompressor: + data = self._decompressor.decompress(data) if self.request.streaming_callback: if self.chunks is None: # if chunks is not None, we already called streaming_callback @@ -175,6 +186,9 @@ class _HTTPConnection(object): # TODO: "chunk extensions" http://tools.ietf.org/html/rfc2616#section-3.6.1 length = int(data.strip(), 16) if length == 0: + # all the data has been decompressed, so we don't need to + # decompress again in _on_body + self._decompressor = None self._on_body(''.join(self.chunks)) else: self.stream.read_bytes(length + 2, # chunk ends with \r\n @@ -183,6 +197,8 @@ class _HTTPConnection(object): def _on_chunk_data(self, data): assert data[-2:] == "\r\n" chunk = data[:-2] + if self._decompressor: + chunk = self._decompressor.decompress(chunk) if self.request.streaming_callback is not None: self.request.streaming_callback(chunk) else: diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index 1bece2a3..d72f01cc 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -1,5 +1,8 @@ #!/usr/bin/env python +import gzip +import logging + from tornado.simple_httpclient import SimpleAsyncHTTPClient from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase from tornado.web import Application, RequestHandler @@ -32,7 +35,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): ("/post", PostHandler), ("/chunk", ChunkHandler), ("/auth", AuthHandler), - ]) + ], gzip=True) def setUp(self): super(SimpleHTTPClientTestCase, self).setUp() @@ -77,3 +80,17 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): self.assertEqual(self.fetch("/auth", auth_username="Aladdin", auth_password="open sesame").body, "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + + def test_gzip(self): + # All the tests in this file should be using gzip, but this test + # ensures that it is in fact getting compressed. + # Setting Accept-Encoding manually bypasses the client's + # decompression so we can see the raw data. + response = self.fetch("/chunk", use_gzip=False, + headers={"Accept-Encoding": "gzip"}) + self.assertEqual(response.headers["Content-Encoding"], "gzip") + self.assertNotEqual(response.body, "asdfqwer") + # Our test data gets bigger when gzipped. Oops. :) + self.assertEqual(len(response.body), 34) + f = gzip.GzipFile(mode="r", fileobj=response.buffer) + self.assertEqual(f.read(), "asdfqwer")