Add timeout support to simple_httpclient
This commit is contained in:
parent
8a941c42b2
commit
f2aa302bcb
|
@ -14,6 +14,7 @@ import functools
|
|||
import logging
|
||||
import re
|
||||
import socket
|
||||
import time
|
||||
import urlparse
|
||||
import zlib
|
||||
|
||||
|
@ -63,6 +64,7 @@ class _HTTPConnection(object):
|
|||
_SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE"])
|
||||
|
||||
def __init__(self, io_loop, request, callback):
|
||||
self.start_time = time.time()
|
||||
self.io_loop = io_loop
|
||||
self.request = request
|
||||
self.callback = callback
|
||||
|
@ -70,6 +72,8 @@ class _HTTPConnection(object):
|
|||
self.headers = None
|
||||
self.chunks = None
|
||||
self._decompressor = None
|
||||
# Timeout handle returned by IOLoop.add_timeout
|
||||
self._timeout = None
|
||||
with stack_context.StackContext(self.cleanup):
|
||||
parsed = urlparse.urlsplit(self.request.url)
|
||||
if ":" in parsed.netloc:
|
||||
|
@ -86,10 +90,30 @@ class _HTTPConnection(object):
|
|||
else:
|
||||
self.stream = IOStream(socket.socket(),
|
||||
io_loop=self.io_loop)
|
||||
timeout = min(request.connect_timeout, request.request_timeout)
|
||||
if timeout:
|
||||
self._connect_timeout = self.io_loop.add_timeout(
|
||||
self.start_time + timeout,
|
||||
self._on_timeout)
|
||||
self.stream.connect((host, port),
|
||||
functools.partial(self._on_connect, parsed))
|
||||
|
||||
def _on_timeout(self):
|
||||
self._timeout = None
|
||||
self.stream.close()
|
||||
if self.callback is not None:
|
||||
self.callback(HTTPResponse(self.request, 599,
|
||||
error=HTTPError(599, "Timeout")))
|
||||
self.callback = None
|
||||
|
||||
def _on_connect(self, parsed):
|
||||
if self._timeout is not None:
|
||||
self.io_loop.remove_callback(self._timeout)
|
||||
self._timeout = None
|
||||
if self.request.request_timeout:
|
||||
self._timeout = self.io_loop.add_timeout(
|
||||
self.start_time + self.request.request_timeout,
|
||||
self._on_timeout)
|
||||
if (self.request.method not in self._SUPPORTED_METHODS and
|
||||
not self.request.allow_nonstandard_methods):
|
||||
raise KeyError("unknown method %s" % self.request.method)
|
||||
|
@ -167,6 +191,9 @@ class _HTTPConnection(object):
|
|||
"don't know how to read %s", self.request.url)
|
||||
|
||||
def _on_body(self, data):
|
||||
if self._timeout is not None:
|
||||
self.io_loop.remove_timeout(self._timeout)
|
||||
self._timeout = None
|
||||
if self._decompressor:
|
||||
data = self._decompressor.decompress(data)
|
||||
if self.request.streaming_callback:
|
||||
|
|
|
@ -2,10 +2,12 @@
|
|||
|
||||
import gzip
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from contextlib import closing
|
||||
from tornado.simple_httpclient import SimpleAsyncHTTPClient
|
||||
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase
|
||||
from tornado.web import Application, RequestHandler
|
||||
from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase, get_unused_port
|
||||
from tornado.web import Application, RequestHandler, asynchronous
|
||||
|
||||
class HelloWorldHandler(RequestHandler):
|
||||
def get(self):
|
||||
|
@ -28,6 +30,11 @@ class AuthHandler(RequestHandler):
|
|||
def get(self):
|
||||
self.finish(self.request.headers["Authorization"])
|
||||
|
||||
class HangHandler(RequestHandler):
|
||||
@asynchronous
|
||||
def get(self):
|
||||
pass
|
||||
|
||||
class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
||||
def get_app(self):
|
||||
return Application([
|
||||
|
@ -35,6 +42,7 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
("/post", PostHandler),
|
||||
("/chunk", ChunkHandler),
|
||||
("/auth", AuthHandler),
|
||||
("/hang", HangHandler),
|
||||
], gzip=True)
|
||||
|
||||
def setUp(self):
|
||||
|
@ -94,3 +102,23 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase):
|
|||
self.assertEqual(len(response.body), 34)
|
||||
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
|
||||
self.assertEqual(f.read(), "asdfqwer")
|
||||
|
||||
def test_connect_timeout(self):
|
||||
# create a socket and bind it to a port, but don't
|
||||
# call accept so the connection will timeout.
|
||||
#get_unused_port()
|
||||
port = get_unused_port()
|
||||
|
||||
with closing(socket.socket()) as sock:
|
||||
sock.bind(('', port))
|
||||
self.http_client.fetch("http://localhost:%d/" % port,
|
||||
self.stop,
|
||||
connect_timeout=0.1)
|
||||
response = self.wait()
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
|
||||
def test_request_timeout(self):
|
||||
response = self.fetch('/hang', request_timeout=0.1)
|
||||
self.assertEqual(response.code, 599)
|
||||
self.assertEqual(str(response.error), "HTTP 599: Timeout")
|
||||
|
|
Loading…
Reference in New Issue