diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index 2c657a8a..86a6c0dd 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -14,6 +14,7 @@ import functools import logging import re import socket +import time import urlparse import zlib @@ -63,6 +64,7 @@ class _HTTPConnection(object): _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"]) def __init__(self, io_loop, request, callback): + self.start_time = time.time() self.io_loop = io_loop self.request = request self.callback = callback @@ -70,6 +72,8 @@ class _HTTPConnection(object): self.headers = None self.chunks = None self._decompressor = None + # Timeout handle returned by IOLoop.add_timeout + self._timeout = None with stack_context.StackContext(self.cleanup): parsed = urlparse.urlsplit(self.request.url) if ":" in parsed.netloc: @@ -86,10 +90,30 @@ class _HTTPConnection(object): else: self.stream = IOStream(socket.socket(), io_loop=self.io_loop) + timeout = min(request.connect_timeout, request.request_timeout) + if timeout: + self._connect_timeout = self.io_loop.add_timeout( + self.start_time + timeout, + self._on_timeout) self.stream.connect((host, port), functools.partial(self._on_connect, parsed)) + def _on_timeout(self): + self._timeout = None + self.stream.close() + if self.callback is not None: + self.callback(HTTPResponse(self.request, 599, + error=HTTPError(599, "Timeout"))) + self.callback = None + def _on_connect(self, parsed): + if self._timeout is not None: + self.io_loop.remove_callback(self._timeout) + self._timeout = None + if self.request.request_timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + self.request.request_timeout, + self._on_timeout) if (self.request.method not in self._SUPPORTED_METHODS and not self.request.allow_nonstandard_methods): raise KeyError("unknown method %s" % self.request.method) @@ -167,6 +191,9 @@ class _HTTPConnection(object): "don't know how to read %s", self.request.url) def _on_body(self, data): + if self._timeout is not None: + self.io_loop.remove_timeout(self._timeout) + self._timeout = None if self._decompressor: data = self._decompressor.decompress(data) if self.request.streaming_callback: diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index d72f01cc..c99e547a 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -2,10 +2,12 @@ import gzip import logging +import socket +from contextlib import closing from tornado.simple_httpclient import SimpleAsyncHTTPClient -from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase -from tornado.web import Application, RequestHandler +from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port +from tornado.web import Application, RequestHandler, asynchronous class HelloWorldHandler(RequestHandler): def get(self): @@ -28,6 +30,11 @@ class AuthHandler(RequestHandler): def get(self): self.finish(self.request.headers["Authorization"]) +class HangHandler(RequestHandler): + @asynchronous + def get(self): + pass + class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): def get_app(self): return Application([ @@ -35,6 +42,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): ("/post", PostHandler), ("/chunk", ChunkHandler), ("/auth", AuthHandler), + ("/hang", HangHandler), ], gzip=True) def setUp(self): @@ -94,3 +102,23 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): self.assertEqual(len(response.body), 34) f = gzip.GzipFile(mode="r", fileobj=response.buffer) self.assertEqual(f.read(), "asdfqwer") + + def test_connect_timeout(self): + # create a socket and bind it to a port, but don't + # call accept so the connection will timeout. + #get_unused_port() + port = get_unused_port() + + with closing(socket.socket()) as sock: + sock.bind(('', port)) + self.http_client.fetch("http://localhost:%d/" % port, + self.stop, + connect_timeout=0.1) + response = self.wait() + self.assertEqual(response.code, 599) + self.assertEqual(str(response.error), "HTTP 599: Timeout") + + def test_request_timeout(self): + response = self.fetch('/hang', request_timeout=0.1) + self.assertEqual(response.code, 599) + self.assertEqual(str(response.error), "HTTP 599: Timeout")