100% unit test coverage for Redis transport

This commit is contained in:
Ask Solem 2010-11-11 11:19:29 +01:00
parent a77bfc93a6
commit 4e32fa3340
3 changed files with 286 additions and 21 deletions

View File

@ -0,0 +1,243 @@
import socket
import types
import unittest2 as unittest
from Queue import Empty, Queue as _Queue
from kombu.connection import BrokerConnection
from kombu.entity import Exchange, Queue
from kombu.messaging import Consumer, Producer
from kombu.transport import pyredis
from kombu.tests.utils import module_exists
class ResponseError(Exception):
pass
class Client(object):
def __init__(self, db=None, port=None, **kwargs):
self.port = port
self.db = db
self._called = []
self.queues = {}
self.sets = {}
self.bgsave_raises_ResponseError = False
def bgsave(self):
self._called.append("BGSAVE")
if self.bgsave_raises_ResponseError:
raise ResponseError()
def delete(self, key):
self.queues.pop(key, None)
def sadd(self, key, member):
if key not in self.sets:
self.sets[key] = set()
self.sets[key].add(member)
def smembers(self, key):
return self.sets.get(key, set())
def llen(self, key):
return self.queues[key].qsize()
def lpush(self, key, value):
self.queues[key].put_nowait(value)
def brpop(self, keys, timeout=None):
key = keys[0]
try:
item = self.queues[key].get(timeout=timeout)
except Empty:
pass
else:
return key, item
def rpop(self, key):
try:
return self.queues[key].get_nowait()
except KeyError:
pass
def __contains__(self, k):
return k in self._called
def _new_queue(self, key):
self.queues[key] = _Queue()
class Channel(pyredis.Channel):
def _get_client(self):
return Client
def _get_response_error(self):
return ResponseError
def _new_queue(self, queue, **kwargs):
self.client._new_queue(queue)
class Transport(pyredis.Transport):
Channel = Channel
def _get_errors(self):
return ((), ())
class test_Redis(unittest.TestCase):
def setUp(self):
self.connection = BrokerConnection(transport=Transport)
self.exchange = Exchange("test_Redis", type="direct")
self.queue = Queue("test_Redis", self.exchange, "test_Redis")
def tearDown(self):
self.connection.close()
def test_publish__get(self):
channel = self.connection.channel()
producer = Producer(channel, self.exchange, routing_key="test_Redis")
self.queue(channel).declare()
producer.publish({"hello": "world"})
self.assertDictEqual(self.queue(channel).get().payload,
{"hello": "world"})
self.assertIsNone(self.queue(channel).get())
self.assertIsNone(self.queue(channel).get())
self.assertIsNone(self.queue(channel).get())
def test_publish__consume(self):
connection = BrokerConnection(transport=Transport)
channel = connection.channel()
producer = Producer(channel, self.exchange, routing_key="test_Redis")
consumer = Consumer(channel, self.queue)
producer.publish({"hello2": "world2"})
_received = []
def callback(message_data, message):
_received.append(message_data)
message.ack()
consumer.register_callback(callback)
consumer.consume()
self.assertTrue(channel._poller._can_start())
try:
connection.drain_events(timeout=1)
self.assertTrue(_received)
self.assertFalse(channel._poller._can_start())
self.assertRaises(socket.timeout,
connection.drain_events, timeout=0.01)
finally:
channel.close()
def test_purge(self):
channel = self.connection.channel()
producer = Producer(channel, self.exchange, routing_key="test_Redis")
self.queue(channel).declare()
for i in range(10):
producer.publish({"hello": "world-%s" % (i, )})
self.assertEqual(channel._size("test_Redis"), 10)
self.assertEqual(self.queue(channel).purge(), 10)
channel.close()
def test_db_values(self):
c1 = BrokerConnection(virtual_host=1,
transport=Transport).channel()
self.assertEqual(c1.client.db, 1)
c2 = BrokerConnection(virtual_host="1",
transport=Transport).channel()
self.assertEqual(c2.client.db, 1)
c3 = BrokerConnection(virtual_host="/1",
transport=Transport).channel()
self.assertEqual(c3.client.db, 1)
c4 = BrokerConnection(virtual_host="/foo",
transport=Transport).channel()
self.assertRaises(ValueError, getattr, c4, "client")
def test_db_port(self):
c1 = BrokerConnection(port=None, transport=Transport).channel()
self.assertEqual(c1.client.port, Transport.default_port)
c1.close()
c2 = BrokerConnection(port=9999, transport=Transport).channel()
self.assertEqual(c2.client.port, 9999)
c2.close()
def test_close_poller_not_active(self):
c = BrokerConnection(transport=Transport).channel()
c.close()
self.assertFalse(c._poller.isAlive())
self.assertTrue("BGSAVE" in c.client)
def test_close_ResponseError(self):
c = BrokerConnection(transport=Transport).channel()
c.client.bgsave_raises_ResponseError = True
c.close()
def test_get__Empty(self):
channel = self.connection.channel()
self.assertRaises(Empty, channel._get, "does-not-exist")
channel.close()
def test_get_client(self):
redis, exceptions = _redis_modules()
@module_exists(redis, exceptions)
def _do_test():
conn = BrokerConnection(transport="redis")
chan = conn.channel()
self.assertTrue(chan.Client)
self.assertTrue(chan.ResponseError)
self.assertTrue(conn.transport.connection_errors)
self.assertTrue(conn.transport.channel_errors)
_do_test()
def _redis_modules():
class ConnectionError(Exception):
pass
class AuthenticationError(Exception):
pass
class InvalidData(Exception):
pass
class InvalidResponse(Exception):
pass
class ResponseError(Exception):
pass
exceptions = types.ModuleType("redis.exceptions")
exceptions.ConnectionError = ConnectionError
exceptions.AuthenticationError = AuthenticationError
exceptions.InvalidData = InvalidData
exceptions.InvalidResponse = InvalidResponse
exceptions.ResponseError = ResponseError
class Redis(object):
pass
redis = types.ModuleType("redis")
redis.exceptions = exceptions
redis.Redis = Redis
return redis, exceptions

View File

@ -12,8 +12,6 @@ from threading import Condition, Event, Lock, Thread
from Queue import Empty, Queue as _Queue
from anyjson import serialize, deserialize
from redis import Redis
from redis import exceptions
from kombu.transport import virtual
@ -41,16 +39,19 @@ class ChannelPoller(Thread):
self.mutex = Lock()
self.poll_request = Condition(self.mutex)
self.shutdown = Event()
self.stopped = Event()
Thread.__init__(self)
self.setDaemon(False)
self.started = False
def run(self):
def run(self): # pragma: no cover
inbound = self.inbound
shutdown = self.shutdown
drain_events = self.drain_events
poll_request = self.poll_request
while 1:
if self.shutdown.isSet():
if shutdown.isSet():
break
try:
@ -58,9 +59,9 @@ class ChannelPoller(Thread):
except Empty:
pass
else:
self.inbound.put_nowait(item)
inbound.put_nowait(item)
if self.shutdown.isSet():
if shutdown.isSet():
break
# Wait for next poll request
@ -76,6 +77,8 @@ class ChannelPoller(Thread):
finally:
poll_request.release()
self.stopped.set()
def poll(self):
# start thread on demand.
self.ensure_started()
@ -91,19 +94,24 @@ class ChannelPoller(Thread):
# for it to put a message onto the inbound queue.
return self.inbound.get(timeout=0.3)
def _can_start(self):
return not (self.started or
self.isAlive() or
self.shutdown.isSet() or
self.stopped.isSet())
def ensure_started(self):
if not self.isAlive():
if self._can_start():
self.started = True
self.start()
def close(self):
self.shutdown.set()
if self.isAlive():
self.shutdown.set()
self.join()
class Channel(virtual.Channel):
Client = Redis
_client = None
supports_fanout = True
keyprefix_fanout = "_kombu.fanout.%s"
@ -115,6 +123,16 @@ class Channel(virtual.Channel):
super_.__init__(*args, **kwargs)
self._poller = ChannelPoller(super_.drain_events)
self.Client = self._get_client()
self.ResponseError = self._get_response_error()
def _get_client(self):
from redis import Redis
return Redis
def _get_response_error(self):
from redis import exceptions
return exceptions.ResponseError
def drain_events(self, timeout=None):
return self._poller.poll()
@ -129,9 +147,6 @@ class Channel(virtual.Channel):
members = self.client.smembers(self.keyprefix_queue % (exchange, ))
return [tuple(val.split(self.sep)) for val in members]
def _new_queue(self, queue, **kwargs):
pass
def _get(self, queue):
item = self.client.rpop(queue)
if item:
@ -161,7 +176,7 @@ class Channel(virtual.Channel):
super(Channel, self).close()
try:
self.client.bgsave()
except exceptions.ResponseError:
except self.ResponseError:
pass
def _open(self):
@ -195,9 +210,16 @@ class Transport(virtual.Transport):
interval = 1
default_port = DEFAULT_PORT
connection_errors = (exceptions.ConnectionError,
exceptions.AuthenticationError)
channel_errors = (exceptions.ConnectionError,
exceptions.InvalidData,
exceptions.InvalidResponse,
exceptions.ResponseError)
def __init__(self, *args, **kwargs):
self.connection_errors, self.channel_errors = self._get_errors()
super(Transport, self).__init__(*args, **kwargs)
def _get_errors(self):
from redis import exceptions
return ((exceptions.ConnectionError,
exceptions.AuthenticationError),
(exceptions.ConnectionError,
exceptions.InvalidData,
exceptions.InvalidResponse,
exceptions.ResponseError))

View File

@ -9,7 +9,7 @@ cover3-exclude = kombu
kombu.simple
kombu.utils.compat
kombu.utils.functional
kombu.transport.pyredis
kombu.utils.finalize
kombu.transport.pypika
kombu.transport.pycouchdb
kombu.transport.mongodb