Adding a unit test for MongoDB transport. Also purging and getting fanout queue size properly.

This commit is contained in:
Scott Lyons 2012-01-31 12:51:27 -05:00
parent a4a1381ff0
commit 03ae06c362
2 changed files with 131 additions and 9 deletions

View File

@ -0,0 +1,112 @@
from __future__ import absolute_import
from __future__ import with_statement
import socket
from ..connection import BrokerConnection
from ..entity import Exchange, Queue
from ..messaging import Consumer, Producer
from .utils import TestCase
import sys
class test_MongoDBTransport(TestCase):
def setUp(self):
self.c = BrokerConnection(transport="mongodb")
self.e = Exchange("test_transport_mongodb", type="fanout")
self.q = Queue("test_transport_mongodb",
exchange=self.e,
routing_key="test_transport_mongodb")
self.q2 = Queue("test_transport_memory2",
exchange=self.e,
routing_key="test_transport_mongodb2")
def test_fanout(self):
return
fanexch = Exchange("test_transport_mongodb_fanout", type="fanout")
channel = self.c.channel()
producer = Producer(channel, fanexch)
_received = []
self.assertEqual(len(channel._queue_cursors), 0)
badMsg = {"please":"noFails"}
producer.publish(badMsg)
def callback(message_data, message):
print "_callback"
_received.append(message)
message.ack()
goodMsg = {"something":"important"}
fanqueue = Queue("test_transport_mongodb_fanq", exchange=fanexch, channel=channel)
consumer = Consumer(channel, fanqueue)
consumer.register_callback(callback)
consumer.register_callback(lambda p,x: sys.stdout("_derp"))
self.assertIn("test_transport_mongodb_fanq", channel._fanout_queues)
producer.publish(goodMsg)
consumer.consume()
channel.drain_events()
self.assertIn(goodMsg, _received)
self.assertNotIn(badMsg, _received)
fanqueue.delete()
self.assertEqual(len(channel._queue_cursors), 0)
def test_produce_consume(self):
channel = self.c.channel()
producer = Producer(channel, self.e)
consumer1 = Consumer(channel, self.q)
consumer2 = Consumer(channel, self.q2)
self.q2(channel).declare()
for i in range(10):
producer.publish({"foo": i}, routing_key="test_transport_mongodb")
for i in range(10):
producer.publish({"foo": i}, routing_key="test_transport_mongodb2")
_received1 = []
_received2 = []
def callback1(message_data, message):
_received1.append(message)
message.ack()
def callback2(message_data, message):
_received2.append(message)
message.ack()
consumer1.register_callback(callback1)
consumer2.register_callback(callback2)
consumer1.consume()
consumer2.consume()
while 1:
if len(_received1) + len(_received2) == 20:
break
self.c.drain_events()
self.assertEqual(len(_received1) + len(_received2), 20)
# queue.delete
for i in range(10):
producer.publish({"foo": i}, routing_key="test_transport_mongodb")
self.assertTrue(self.q(channel).get())
self.q(channel).delete()
self.q(channel).declare()
self.assertIsNone(self.q(channel).get())
# queue.purge
for i in range(10):
producer.publish({"foo": i}, routing_key="test_transport_mongodb2")
self.assertTrue(self.q2(channel).get())
self.q2(channel).purge()
self.assertIsNone(self.q2(channel).get())

View File

@ -35,7 +35,7 @@ class Channel(virtual.Channel):
super_.__init__(*vargs, **kwargs)
self._queue_cursors = {}
self._queue_readcounts = {}
def _new_queue(self, queue, **kwargs):
pass
@ -44,9 +44,10 @@ class Channel(virtual.Channel):
try:
if queue in self._fanout_queues:
msg = self._queue_cursors[queue].next()
self._queue_readcounts[queue]+=1
return loads(msg["payload"])
else:
msg = self.client.database.command("findandmodify", "messages",
msg = self.client.command("findandmodify", "messages",
query={"queue": queue},
sort={"_id": pymongo.ASCENDING}, remove=True)
except errors.OperationFailure, exc:
@ -62,14 +63,22 @@ class Channel(virtual.Channel):
return loads(msg["value"]["payload"])
def _size(self, queue):
return self.client.find({"queue": queue}).count()
if queue in self._fanout_queues:
return self._queue_cursors[queue].count() - self._queue_readcounts[queue]
return self.client.messages.find({"queue": queue}).count()
def _put(self, queue, message, **kwargs):
self.client.insert({"payload": dumps(message), "queue": queue})
self.client.messages.insert({"payload": dumps(message), "queue": queue})
def _purge(self, queue):
size = self._size(queue)
self.client.remove({"queue": queue})
if queue in self._fanout_queues:
cursor = self._queue_cursors[queue]
cursor.rewind()
self._queue_cursors[queue] = cursor.skip(cursor.count())
else:
self.client.messages.remove({"queue": queue})
return size
def close(self):
@ -105,11 +114,11 @@ class Channel(virtual.Channel):
self.routing = getattr(database, "messages.routing")
self.routing.ensure_index([("queue", 1), ("exchange", 1)])
return col
return database
def get_table(self, exchange):
"""Get table of bindings for `exchange`."""
brokerRoutes = self.routing.find({"exchange":exchange})
brokerRoutes = self.client.messages.routing.find({"exchange":exchange})
localRoutes = self.state.exchanges[exchange]["table"]
for route in brokerRoutes:
@ -118,17 +127,18 @@ class Channel(virtual.Channel):
def _put_fanout(self, exchange, message, **kwargs):
"""Deliver fanout message."""
self.bcast.insert({"payload": serialize(message), "queue": exchange})
self.client.messages.broadcast.insert({"payload": dumps(message), "queue": exchange})
def _queue_bind(self, exchange, routing_key, pattern, queue):
if self.typeof(exchange).type == "fanout":
cursor = self.bcast.find(query={"queue":exchange}, sort=[("$natural", 1)], tailable=True)
# Fast forward the cursor past old events
self._queue_cursors[queue] = cursor.skip(cursor.count())
self._queue_readcounts[queue] = cursor.count()
self._fanout_queues[queue] = exchange
meta = dict(exchange=exchange, queue=queue, routing_key=routing_key, pattern=pattern)
self.routing.update(meta, meta, upsert=True)
self.client.messages.routing.update(meta, meta, upsert=True)
def queue_delete(self, queue, if_unusued=False, if_empty=False, **kwargs):