mirror of https://github.com/celery/kombu.git
Adding a unit test for MongoDB transport. Also purging and getting fanout queue size properly.
This commit is contained in:
parent
a4a1381ff0
commit
03ae06c362
|
@ -0,0 +1,112 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import with_statement
|
||||
|
||||
import socket
|
||||
|
||||
from ..connection import BrokerConnection
|
||||
from ..entity import Exchange, Queue
|
||||
from ..messaging import Consumer, Producer
|
||||
|
||||
from .utils import TestCase
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
class test_MongoDBTransport(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.c = BrokerConnection(transport="mongodb")
|
||||
self.e = Exchange("test_transport_mongodb", type="fanout")
|
||||
self.q = Queue("test_transport_mongodb",
|
||||
exchange=self.e,
|
||||
routing_key="test_transport_mongodb")
|
||||
self.q2 = Queue("test_transport_memory2",
|
||||
exchange=self.e,
|
||||
routing_key="test_transport_mongodb2")
|
||||
|
||||
def test_fanout(self):
|
||||
return
|
||||
fanexch = Exchange("test_transport_mongodb_fanout", type="fanout")
|
||||
channel = self.c.channel()
|
||||
producer = Producer(channel, fanexch)
|
||||
_received = []
|
||||
|
||||
self.assertEqual(len(channel._queue_cursors), 0)
|
||||
badMsg = {"please":"noFails"}
|
||||
producer.publish(badMsg)
|
||||
|
||||
def callback(message_data, message):
|
||||
print "_callback"
|
||||
_received.append(message)
|
||||
message.ack()
|
||||
|
||||
goodMsg = {"something":"important"}
|
||||
fanqueue = Queue("test_transport_mongodb_fanq", exchange=fanexch, channel=channel)
|
||||
|
||||
consumer = Consumer(channel, fanqueue)
|
||||
consumer.register_callback(callback)
|
||||
consumer.register_callback(lambda p,x: sys.stdout("_derp"))
|
||||
|
||||
self.assertIn("test_transport_mongodb_fanq", channel._fanout_queues)
|
||||
|
||||
producer.publish(goodMsg)
|
||||
consumer.consume()
|
||||
channel.drain_events()
|
||||
|
||||
self.assertIn(goodMsg, _received)
|
||||
self.assertNotIn(badMsg, _received)
|
||||
|
||||
fanqueue.delete()
|
||||
self.assertEqual(len(channel._queue_cursors), 0)
|
||||
|
||||
|
||||
def test_produce_consume(self):
|
||||
channel = self.c.channel()
|
||||
producer = Producer(channel, self.e)
|
||||
consumer1 = Consumer(channel, self.q)
|
||||
consumer2 = Consumer(channel, self.q2)
|
||||
self.q2(channel).declare()
|
||||
|
||||
for i in range(10):
|
||||
producer.publish({"foo": i}, routing_key="test_transport_mongodb")
|
||||
for i in range(10):
|
||||
producer.publish({"foo": i}, routing_key="test_transport_mongodb2")
|
||||
|
||||
_received1 = []
|
||||
_received2 = []
|
||||
|
||||
def callback1(message_data, message):
|
||||
_received1.append(message)
|
||||
message.ack()
|
||||
|
||||
def callback2(message_data, message):
|
||||
_received2.append(message)
|
||||
message.ack()
|
||||
|
||||
consumer1.register_callback(callback1)
|
||||
consumer2.register_callback(callback2)
|
||||
|
||||
consumer1.consume()
|
||||
consumer2.consume()
|
||||
|
||||
while 1:
|
||||
if len(_received1) + len(_received2) == 20:
|
||||
break
|
||||
self.c.drain_events()
|
||||
|
||||
self.assertEqual(len(_received1) + len(_received2), 20)
|
||||
|
||||
# queue.delete
|
||||
for i in range(10):
|
||||
producer.publish({"foo": i}, routing_key="test_transport_mongodb")
|
||||
self.assertTrue(self.q(channel).get())
|
||||
self.q(channel).delete()
|
||||
self.q(channel).declare()
|
||||
self.assertIsNone(self.q(channel).get())
|
||||
|
||||
# queue.purge
|
||||
for i in range(10):
|
||||
producer.publish({"foo": i}, routing_key="test_transport_mongodb2")
|
||||
self.assertTrue(self.q2(channel).get())
|
||||
self.q2(channel).purge()
|
||||
self.assertIsNone(self.q2(channel).get())
|
|
@ -35,7 +35,7 @@ class Channel(virtual.Channel):
|
|||
super_.__init__(*vargs, **kwargs)
|
||||
|
||||
self._queue_cursors = {}
|
||||
|
||||
self._queue_readcounts = {}
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
pass
|
||||
|
@ -44,9 +44,10 @@ class Channel(virtual.Channel):
|
|||
try:
|
||||
if queue in self._fanout_queues:
|
||||
msg = self._queue_cursors[queue].next()
|
||||
self._queue_readcounts[queue]+=1
|
||||
return loads(msg["payload"])
|
||||
else:
|
||||
msg = self.client.database.command("findandmodify", "messages",
|
||||
msg = self.client.command("findandmodify", "messages",
|
||||
query={"queue": queue},
|
||||
sort={"_id": pymongo.ASCENDING}, remove=True)
|
||||
except errors.OperationFailure, exc:
|
||||
|
@ -62,14 +63,22 @@ class Channel(virtual.Channel):
|
|||
return loads(msg["value"]["payload"])
|
||||
|
||||
def _size(self, queue):
|
||||
return self.client.find({"queue": queue}).count()
|
||||
if queue in self._fanout_queues:
|
||||
return self._queue_cursors[queue].count() - self._queue_readcounts[queue]
|
||||
|
||||
return self.client.messages.find({"queue": queue}).count()
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
self.client.insert({"payload": dumps(message), "queue": queue})
|
||||
self.client.messages.insert({"payload": dumps(message), "queue": queue})
|
||||
|
||||
def _purge(self, queue):
|
||||
size = self._size(queue)
|
||||
self.client.remove({"queue": queue})
|
||||
if queue in self._fanout_queues:
|
||||
cursor = self._queue_cursors[queue]
|
||||
cursor.rewind()
|
||||
self._queue_cursors[queue] = cursor.skip(cursor.count())
|
||||
else:
|
||||
self.client.messages.remove({"queue": queue})
|
||||
return size
|
||||
|
||||
def close(self):
|
||||
|
@ -105,11 +114,11 @@ class Channel(virtual.Channel):
|
|||
|
||||
self.routing = getattr(database, "messages.routing")
|
||||
self.routing.ensure_index([("queue", 1), ("exchange", 1)])
|
||||
return col
|
||||
return database
|
||||
|
||||
def get_table(self, exchange):
|
||||
"""Get table of bindings for `exchange`."""
|
||||
brokerRoutes = self.routing.find({"exchange":exchange})
|
||||
brokerRoutes = self.client.messages.routing.find({"exchange":exchange})
|
||||
|
||||
localRoutes = self.state.exchanges[exchange]["table"]
|
||||
for route in brokerRoutes:
|
||||
|
@ -118,17 +127,18 @@ class Channel(virtual.Channel):
|
|||
|
||||
def _put_fanout(self, exchange, message, **kwargs):
|
||||
"""Deliver fanout message."""
|
||||
self.bcast.insert({"payload": serialize(message), "queue": exchange})
|
||||
self.client.messages.broadcast.insert({"payload": dumps(message), "queue": exchange})
|
||||
|
||||
def _queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
if self.typeof(exchange).type == "fanout":
|
||||
cursor = self.bcast.find(query={"queue":exchange}, sort=[("$natural", 1)], tailable=True)
|
||||
# Fast forward the cursor past old events
|
||||
self._queue_cursors[queue] = cursor.skip(cursor.count())
|
||||
self._queue_readcounts[queue] = cursor.count()
|
||||
self._fanout_queues[queue] = exchange
|
||||
|
||||
meta = dict(exchange=exchange, queue=queue, routing_key=routing_key, pattern=pattern)
|
||||
self.routing.update(meta, meta, upsert=True)
|
||||
self.client.messages.routing.update(meta, meta, upsert=True)
|
||||
|
||||
|
||||
def queue_delete(self, queue, if_unusued=False, if_empty=False, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue