mirror of https://github.com/celery/kombu.git
100% unit test coverage for Redis transport
This commit is contained in:
parent
a77bfc93a6
commit
4e32fa3340
|
@ -0,0 +1,243 @@
|
|||
import socket
|
||||
import types
|
||||
import unittest2 as unittest
|
||||
|
||||
from Queue import Empty, Queue as _Queue
|
||||
|
||||
from kombu.connection import BrokerConnection
|
||||
from kombu.entity import Exchange, Queue
|
||||
from kombu.messaging import Consumer, Producer
|
||||
|
||||
from kombu.transport import pyredis
|
||||
|
||||
from kombu.tests.utils import module_exists
|
||||
|
||||
|
||||
class ResponseError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Client(object):
|
||||
|
||||
def __init__(self, db=None, port=None, **kwargs):
|
||||
self.port = port
|
||||
self.db = db
|
||||
self._called = []
|
||||
self.queues = {}
|
||||
self.sets = {}
|
||||
self.bgsave_raises_ResponseError = False
|
||||
|
||||
def bgsave(self):
|
||||
self._called.append("BGSAVE")
|
||||
if self.bgsave_raises_ResponseError:
|
||||
raise ResponseError()
|
||||
|
||||
def delete(self, key):
|
||||
self.queues.pop(key, None)
|
||||
|
||||
def sadd(self, key, member):
|
||||
if key not in self.sets:
|
||||
self.sets[key] = set()
|
||||
self.sets[key].add(member)
|
||||
|
||||
def smembers(self, key):
|
||||
return self.sets.get(key, set())
|
||||
|
||||
def llen(self, key):
|
||||
return self.queues[key].qsize()
|
||||
|
||||
def lpush(self, key, value):
|
||||
self.queues[key].put_nowait(value)
|
||||
|
||||
def brpop(self, keys, timeout=None):
|
||||
key = keys[0]
|
||||
try:
|
||||
item = self.queues[key].get(timeout=timeout)
|
||||
except Empty:
|
||||
pass
|
||||
else:
|
||||
return key, item
|
||||
|
||||
def rpop(self, key):
|
||||
try:
|
||||
return self.queues[key].get_nowait()
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def __contains__(self, k):
|
||||
return k in self._called
|
||||
|
||||
def _new_queue(self, key):
|
||||
self.queues[key] = _Queue()
|
||||
|
||||
|
||||
class Channel(pyredis.Channel):
|
||||
|
||||
def _get_client(self):
|
||||
return Client
|
||||
|
||||
def _get_response_error(self):
|
||||
return ResponseError
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
self.client._new_queue(queue)
|
||||
|
||||
|
||||
class Transport(pyredis.Transport):
|
||||
Channel = Channel
|
||||
|
||||
def _get_errors(self):
|
||||
return ((), ())
|
||||
|
||||
|
||||
class test_Redis(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.connection = BrokerConnection(transport=Transport)
|
||||
self.exchange = Exchange("test_Redis", type="direct")
|
||||
self.queue = Queue("test_Redis", self.exchange, "test_Redis")
|
||||
|
||||
def tearDown(self):
|
||||
self.connection.close()
|
||||
|
||||
def test_publish__get(self):
|
||||
channel = self.connection.channel()
|
||||
producer = Producer(channel, self.exchange, routing_key="test_Redis")
|
||||
self.queue(channel).declare()
|
||||
|
||||
producer.publish({"hello": "world"})
|
||||
|
||||
self.assertDictEqual(self.queue(channel).get().payload,
|
||||
{"hello": "world"})
|
||||
self.assertIsNone(self.queue(channel).get())
|
||||
self.assertIsNone(self.queue(channel).get())
|
||||
self.assertIsNone(self.queue(channel).get())
|
||||
|
||||
def test_publish__consume(self):
|
||||
connection = BrokerConnection(transport=Transport)
|
||||
channel = connection.channel()
|
||||
producer = Producer(channel, self.exchange, routing_key="test_Redis")
|
||||
consumer = Consumer(channel, self.queue)
|
||||
|
||||
producer.publish({"hello2": "world2"})
|
||||
_received = []
|
||||
|
||||
def callback(message_data, message):
|
||||
_received.append(message_data)
|
||||
message.ack()
|
||||
|
||||
consumer.register_callback(callback)
|
||||
consumer.consume()
|
||||
|
||||
self.assertTrue(channel._poller._can_start())
|
||||
try:
|
||||
connection.drain_events(timeout=1)
|
||||
self.assertTrue(_received)
|
||||
self.assertFalse(channel._poller._can_start())
|
||||
self.assertRaises(socket.timeout,
|
||||
connection.drain_events, timeout=0.01)
|
||||
finally:
|
||||
channel.close()
|
||||
|
||||
def test_purge(self):
|
||||
channel = self.connection.channel()
|
||||
producer = Producer(channel, self.exchange, routing_key="test_Redis")
|
||||
self.queue(channel).declare()
|
||||
|
||||
for i in range(10):
|
||||
producer.publish({"hello": "world-%s" % (i, )})
|
||||
|
||||
self.assertEqual(channel._size("test_Redis"), 10)
|
||||
self.assertEqual(self.queue(channel).purge(), 10)
|
||||
channel.close()
|
||||
|
||||
def test_db_values(self):
|
||||
c1 = BrokerConnection(virtual_host=1,
|
||||
transport=Transport).channel()
|
||||
self.assertEqual(c1.client.db, 1)
|
||||
|
||||
c2 = BrokerConnection(virtual_host="1",
|
||||
transport=Transport).channel()
|
||||
self.assertEqual(c2.client.db, 1)
|
||||
|
||||
c3 = BrokerConnection(virtual_host="/1",
|
||||
transport=Transport).channel()
|
||||
self.assertEqual(c3.client.db, 1)
|
||||
|
||||
c4 = BrokerConnection(virtual_host="/foo",
|
||||
transport=Transport).channel()
|
||||
self.assertRaises(ValueError, getattr, c4, "client")
|
||||
|
||||
def test_db_port(self):
|
||||
c1 = BrokerConnection(port=None, transport=Transport).channel()
|
||||
self.assertEqual(c1.client.port, Transport.default_port)
|
||||
c1.close()
|
||||
|
||||
c2 = BrokerConnection(port=9999, transport=Transport).channel()
|
||||
self.assertEqual(c2.client.port, 9999)
|
||||
c2.close()
|
||||
|
||||
def test_close_poller_not_active(self):
|
||||
c = BrokerConnection(transport=Transport).channel()
|
||||
c.close()
|
||||
self.assertFalse(c._poller.isAlive())
|
||||
self.assertTrue("BGSAVE" in c.client)
|
||||
|
||||
def test_close_ResponseError(self):
|
||||
c = BrokerConnection(transport=Transport).channel()
|
||||
c.client.bgsave_raises_ResponseError = True
|
||||
c.close()
|
||||
|
||||
def test_get__Empty(self):
|
||||
channel = self.connection.channel()
|
||||
self.assertRaises(Empty, channel._get, "does-not-exist")
|
||||
channel.close()
|
||||
|
||||
def test_get_client(self):
|
||||
|
||||
redis, exceptions = _redis_modules()
|
||||
|
||||
@module_exists(redis, exceptions)
|
||||
def _do_test():
|
||||
conn = BrokerConnection(transport="redis")
|
||||
chan = conn.channel()
|
||||
self.assertTrue(chan.Client)
|
||||
self.assertTrue(chan.ResponseError)
|
||||
self.assertTrue(conn.transport.connection_errors)
|
||||
self.assertTrue(conn.transport.channel_errors)
|
||||
|
||||
_do_test()
|
||||
|
||||
|
||||
def _redis_modules():
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidData(Exception):
|
||||
pass
|
||||
|
||||
class InvalidResponse(Exception):
|
||||
pass
|
||||
|
||||
class ResponseError(Exception):
|
||||
pass
|
||||
|
||||
exceptions = types.ModuleType("redis.exceptions")
|
||||
exceptions.ConnectionError = ConnectionError
|
||||
exceptions.AuthenticationError = AuthenticationError
|
||||
exceptions.InvalidData = InvalidData
|
||||
exceptions.InvalidResponse = InvalidResponse
|
||||
exceptions.ResponseError = ResponseError
|
||||
|
||||
class Redis(object):
|
||||
pass
|
||||
|
||||
redis = types.ModuleType("redis")
|
||||
redis.exceptions = exceptions
|
||||
redis.Redis = Redis
|
||||
|
||||
return redis, exceptions
|
|
@ -12,8 +12,6 @@ from threading import Condition, Event, Lock, Thread
|
|||
from Queue import Empty, Queue as _Queue
|
||||
|
||||
from anyjson import serialize, deserialize
|
||||
from redis import Redis
|
||||
from redis import exceptions
|
||||
|
||||
from kombu.transport import virtual
|
||||
|
||||
|
@ -41,16 +39,19 @@ class ChannelPoller(Thread):
|
|||
self.mutex = Lock()
|
||||
self.poll_request = Condition(self.mutex)
|
||||
self.shutdown = Event()
|
||||
self.stopped = Event()
|
||||
Thread.__init__(self)
|
||||
self.setDaemon(False)
|
||||
self.started = False
|
||||
|
||||
def run(self):
|
||||
def run(self): # pragma: no cover
|
||||
inbound = self.inbound
|
||||
shutdown = self.shutdown
|
||||
drain_events = self.drain_events
|
||||
poll_request = self.poll_request
|
||||
|
||||
while 1:
|
||||
if self.shutdown.isSet():
|
||||
if shutdown.isSet():
|
||||
break
|
||||
|
||||
try:
|
||||
|
@ -58,9 +59,9 @@ class ChannelPoller(Thread):
|
|||
except Empty:
|
||||
pass
|
||||
else:
|
||||
self.inbound.put_nowait(item)
|
||||
inbound.put_nowait(item)
|
||||
|
||||
if self.shutdown.isSet():
|
||||
if shutdown.isSet():
|
||||
break
|
||||
|
||||
# Wait for next poll request
|
||||
|
@ -76,6 +77,8 @@ class ChannelPoller(Thread):
|
|||
finally:
|
||||
poll_request.release()
|
||||
|
||||
self.stopped.set()
|
||||
|
||||
def poll(self):
|
||||
# start thread on demand.
|
||||
self.ensure_started()
|
||||
|
@ -91,19 +94,24 @@ class ChannelPoller(Thread):
|
|||
# for it to put a message onto the inbound queue.
|
||||
return self.inbound.get(timeout=0.3)
|
||||
|
||||
def _can_start(self):
|
||||
return not (self.started or
|
||||
self.isAlive() or
|
||||
self.shutdown.isSet() or
|
||||
self.stopped.isSet())
|
||||
|
||||
def ensure_started(self):
|
||||
if not self.isAlive():
|
||||
if self._can_start():
|
||||
self.started = True
|
||||
self.start()
|
||||
|
||||
def close(self):
|
||||
self.shutdown.set()
|
||||
if self.isAlive():
|
||||
self.shutdown.set()
|
||||
self.join()
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
Client = Redis
|
||||
|
||||
_client = None
|
||||
supports_fanout = True
|
||||
keyprefix_fanout = "_kombu.fanout.%s"
|
||||
|
@ -115,6 +123,16 @@ class Channel(virtual.Channel):
|
|||
super_.__init__(*args, **kwargs)
|
||||
|
||||
self._poller = ChannelPoller(super_.drain_events)
|
||||
self.Client = self._get_client()
|
||||
self.ResponseError = self._get_response_error()
|
||||
|
||||
def _get_client(self):
|
||||
from redis import Redis
|
||||
return Redis
|
||||
|
||||
def _get_response_error(self):
|
||||
from redis import exceptions
|
||||
return exceptions.ResponseError
|
||||
|
||||
def drain_events(self, timeout=None):
|
||||
return self._poller.poll()
|
||||
|
@ -129,9 +147,6 @@ class Channel(virtual.Channel):
|
|||
members = self.client.smembers(self.keyprefix_queue % (exchange, ))
|
||||
return [tuple(val.split(self.sep)) for val in members]
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
pass
|
||||
|
||||
def _get(self, queue):
|
||||
item = self.client.rpop(queue)
|
||||
if item:
|
||||
|
@ -161,7 +176,7 @@ class Channel(virtual.Channel):
|
|||
super(Channel, self).close()
|
||||
try:
|
||||
self.client.bgsave()
|
||||
except exceptions.ResponseError:
|
||||
except self.ResponseError:
|
||||
pass
|
||||
|
||||
def _open(self):
|
||||
|
@ -195,9 +210,16 @@ class Transport(virtual.Transport):
|
|||
|
||||
interval = 1
|
||||
default_port = DEFAULT_PORT
|
||||
connection_errors = (exceptions.ConnectionError,
|
||||
exceptions.AuthenticationError)
|
||||
channel_errors = (exceptions.ConnectionError,
|
||||
exceptions.InvalidData,
|
||||
exceptions.InvalidResponse,
|
||||
exceptions.ResponseError)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.connection_errors, self.channel_errors = self._get_errors()
|
||||
super(Transport, self).__init__(*args, **kwargs)
|
||||
|
||||
def _get_errors(self):
|
||||
from redis import exceptions
|
||||
return ((exceptions.ConnectionError,
|
||||
exceptions.AuthenticationError),
|
||||
(exceptions.ConnectionError,
|
||||
exceptions.InvalidData,
|
||||
exceptions.InvalidResponse,
|
||||
exceptions.ResponseError))
|
||||
|
|
Loading…
Reference in New Issue