From 6dbce939982cfac185dab8fcd25bde14528660e1 Mon Sep 17 00:00:00 2001 From: Ben Darnell Date: Sun, 6 May 2012 18:10:13 -0700 Subject: [PATCH] Fix bug when max_clients kwarg is passed to AsyncHTTPClient.configure. Closes #493. --- tornado/httpclient.py | 21 +++++++++++-- tornado/test/simple_httpclient_test.py | 41 +++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/tornado/httpclient.py b/tornado/httpclient.py index 89f0057a..0fcc943f 100644 --- a/tornado/httpclient.py +++ b/tornado/httpclient.py @@ -124,6 +124,8 @@ class AsyncHTTPClient(object): _impl_class = None _impl_kwargs = None + _DEFAULT_MAX_CLIENTS = 10 + @classmethod def _async_clients(cls): assert cls is not AsyncHTTPClient, "should only be called on subclasses" @@ -131,7 +133,7 @@ class AsyncHTTPClient(object): cls._async_client_dict = weakref.WeakKeyDictionary() return cls._async_client_dict - def __new__(cls, io_loop=None, max_clients=10, force_instance=False, + def __new__(cls, io_loop=None, max_clients=None, force_instance=False, **kwargs): io_loop = io_loop or IOLoop.instance() if cls is AsyncHTTPClient: @@ -149,7 +151,13 @@ class AsyncHTTPClient(object): if cls._impl_kwargs: args.update(cls._impl_kwargs) args.update(kwargs) - instance.initialize(io_loop, max_clients, **args) + if max_clients is not None: + # max_clients is special because it may be passed + # positionally instead of by keyword + args["max_clients"] = max_clients + elif "max_clients" not in args: + args["max_clients"] = AsyncHTTPClient._DEFAULT_MAX_CLIENTS + instance.initialize(io_loop, **args) if not force_instance: impl._async_clients()[io_loop] = instance return instance @@ -204,6 +212,15 @@ class AsyncHTTPClient(object): AsyncHTTPClient._impl_class = impl AsyncHTTPClient._impl_kwargs = kwargs + @staticmethod + def _save_configuration(): + return (AsyncHTTPClient._impl_class, AsyncHTTPClient._impl_kwargs) + + @staticmethod + def _restore_configuration(saved): + AsyncHTTPClient._impl_class = saved[0] + AsyncHTTPClient._impl_kwargs = saved[1] + class HTTPRequest(object): """HTTP client request object.""" diff --git a/tornado/test/simple_httpclient_test.py b/tornado/test/simple_httpclient_test.py index bbfd57b1..4a48eb0e 100644 --- a/tornado/test/simple_httpclient_test.py +++ b/tornado/test/simple_httpclient_test.py @@ -1,16 +1,18 @@ from __future__ import absolute_import, division, with_statement import collections +from contextlib import closing import gzip import logging import re import socket +from tornado.httpclient import AsyncHTTPClient from tornado.httputil import HTTPHeaders from tornado.ioloop import IOLoop from tornado.simple_httpclient import SimpleAsyncHTTPClient, _DEFAULT_CA_CERTS from tornado.test.httpclient_test import HTTPClientCommonTestCase, ChunkHandler, CountdownHandler, HelloWorldHandler -from tornado.testing import AsyncHTTPTestCase, LogTrapTestCase +from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, LogTrapTestCase from tornado.util import b from tornado.web import RequestHandler, Application, asynchronous, url @@ -263,3 +265,40 @@ class SimpleHTTPClientTestCase(AsyncHTTPTestCase, LogTrapTestCase): self.http_client.fetch(url, self.stop) response = self.wait() self.assertTrue(host_re.match(response.body), response.body) + + +class CreateAsyncHTTPClientTestCase(AsyncTestCase, LogTrapTestCase): + def setUp(self): + super(CreateAsyncHTTPClientTestCase, self).setUp() + self.saved = AsyncHTTPClient._save_configuration() + + def tearDown(self): + AsyncHTTPClient._restore_configuration(self.saved) + super(CreateAsyncHTTPClientTestCase, self).tearDown() + + def test_max_clients(self): + # The max_clients argument is tricky because it was originally + # allowed to be passed positionally; newer arguments are keyword-only. + AsyncHTTPClient.configure(SimpleAsyncHTTPClient) + with closing(AsyncHTTPClient( + self.io_loop, force_instance=True)) as client: + self.assertEqual(client.max_clients, 10) + with closing(AsyncHTTPClient( + self.io_loop, 11, force_instance=True)) as client: + self.assertEqual(client.max_clients, 11) + with closing(AsyncHTTPClient( + self.io_loop, max_clients=11, force_instance=True)) as client: + self.assertEqual(client.max_clients, 11) + + # Now configure max_clients statically and try overriding it + # with each way max_clients can be passed + AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12) + with closing(AsyncHTTPClient( + self.io_loop, force_instance=True)) as client: + self.assertEqual(client.max_clients, 12) + with closing(AsyncHTTPClient( + self.io_loop, max_clients=13, force_instance=True)) as client: + self.assertEqual(client.max_clients, 13) + with closing(AsyncHTTPClient( + self.io_loop, max_clients=14, force_instance=True)) as client: + self.assertEqual(client.max_clients, 14)