Redis: Fixes a serious issue where larger messages could be lost when consuming with the Redis transport (ask/celery issue #318). Thanks to Simon Zimmerman, Andy McCurdy and Honza Král!
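
The actual fix is in the Redis transport changes at the end of this diff: the two calls that switched the client sockets into non-blocking mode (sock.setblocking(0)) are removed, the idea presumably being that the poller still decides when a socket is readable, while the reads themselves may block until a large reply has arrived in full. As a rough standalone illustration of the failure mode (not kombu or redis-py code; the socket pair, payload size, chunking and timing below are assumptions made only for this sketch), a reader that expects recv() to wait behaves very differently once its socket has been made non-blocking:

# Standalone sketch of the failure mode (illustrative assumptions only;
# this is not kombu or redis-py code).  A sender dribbles a ~1 MiB payload
# over a socket pair while the receiver tries to read all of it.
import socket
import threading
import time

PAYLOAD = "x" * (1 << 20)                 # one "large message"

def slow_sender(sock):
    # Deliver the payload in chunks, the way a real network would.
    for i in xrange(0, len(PAYLOAD), 65536):
        sock.sendall(PAYLOAD[i:i + 65536])
        time.sleep(0.01)
    sock.close()

def read_all(sock):
    received = ""
    while len(received) < len(PAYLOAD):
        chunk = sock.recv(65536)          # on a non-blocking socket this can
        if not chunk:                     # raise EAGAIN halfway through the
            break                         # message instead of waiting
        received += chunk
    return received

for blocking in (True, False):
    a, b = socket.socketpair()
    b.setblocking(blocking)
    sender = threading.Thread(target=slow_sender, args=(a,))
    sender.daemon = True                  # don't hang on exit if the reader bails
    sender.start()
    try:
        print("blocking=%s: read %d/%d bytes" % (
            blocking, len(read_all(b)), len(PAYLOAD)))
    except socket.error as exc:
        print("blocking=%s: failed mid-message: %s" % (blocking, exc))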

Ask Solem 2011-04-06 15:59:16 +02:00
parent ee2ae37640
commit 348babd81b
4 changed files with 81 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ class test_beanstalk(transport.TransportCase):
     transport = "beanstalk"
     prefix = "beanstalk"
     event_loop_max = 10
+    message_size_limit = 47662
 
     def after_connect(self, connection):
         connection.channel().client

View File

@@ -1,11 +1,17 @@
 from funtests import transport
 
+from nose import SkipTest
+
 
 class test_pika_blocking(transport.TransportCase):
     transport = "syncpika"
     prefix = "syncpika"
 
+    def test_produce__consume_large_messages(self, *args, **kwargs):
+        raise SkipTest("test currently fails for sync pika")
+
 
 class test_pika_async(transport.TransportCase):
     transport = "pika"
     prefix = "pika"
+
+    def test_produce__consume_large_messages(self, *args, **kwargs):
+        raise SkipTest("test currently fails for async pika")

View File

@@ -1,14 +1,27 @@
+import random
 import socket
+import string
 import sys
 import time
 import unittest2 as unittest
+import warnings
 
 from nose import SkipTest
 
 from kombu import BrokerConnection
 from kombu import Producer, Consumer, Exchange, Queue
 
+if sys.version_info >= (2, 5):
+    from hashlib import sha256 as _digest
+else:
+    from sha import new as _digest
+
 
-def consumeN(conn, consumer, n=1):
+def say(msg):
+    sys.stderr.write(unicode(msg) + "\n")
+
+
+def consumeN(conn, consumer, n=1, timeout=30):
     messages = []
 
     def callback(message_data, message):
@@ -18,11 +31,18 @@ def consumeN(conn, consumer, n=1):
     prev, consumer.callbacks = consumer.callbacks, [callback]
     consumer.consume()
 
+    seconds = 0
     while True:
         try:
             conn.drain_events(timeout=1)
         except socket.timeout:
-            pass
+            seconds += 1
+            msg = "Received %s/%s messages. %s seconds passed." % (
+                    len(messages), n, seconds)
+            if seconds >= timeout:
+                raise socket.timeout(msg)
+            if seconds > 1:
+                say(msg)
         if len(messages) >= n:
             break
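
The helper above leans on standard kombu behaviour: BrokerConnection.drain_events(timeout=1) dispatches at most one event per call and raises socket.timeout when nothing arrives within that second, so the loop keeps its own idle-seconds counter as the real deadline. A minimal standalone version of the same pattern (the function name and the one-second granularity are just this sketch's choices, not part of kombu):

import socket

def drain_for(conn, consumer, n, deadline=30):
    # conn is a kombu BrokerConnection; bodies are collected through the
    # consumer's callbacks, just as the harness above does.
    received = []

    def on_message(body, message):
        received.append(body)
        message.ack()

    consumer.callbacks = [on_message]
    consumer.consume()
    idle = 0
    while len(received) < n:
        try:
            conn.drain_events(timeout=1)    # waits for at most one event
        except socket.timeout:
            idle += 1                       # only idle seconds count
            if idle >= deadline:
                raise
    return received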
@@ -40,6 +60,8 @@ class TransportCase(unittest.TestCase):
     connected = False
     skip_test_reason = None
 
+    message_size_limit = None
+
     def before_connect(self):
         pass
@@ -76,13 +98,14 @@ class TransportCase(unittest.TestCase):
         self.connected = True
 
     def verify_alive(self):
-        if not self.connected:
-            raise SkipTest(self.skip_test_reason)
+        if self.transport:
+            if not self.connected:
+                raise SkipTest(self.skip_test_reason)
+            return True
 
     def test_produce__consume(self):
-        if not self.transport:
+        if not self.verify_alive():
             return
-        self.verify_alive()
         chan1 = self.connection.channel()
         consumer = Consumer(chan1, self.queue)
         consumer.queues[0].purge()
@@ -93,13 +116,54 @@ class TransportCase(unittest.TestCase):
         chan1.close()
         self.purge([self.queue.name])
 
+    def _digest(self, data):
+        return _digest(data).hexdigest()
+
+    def test_produce__consume_large_messages(self, bytes=1048576, n=10,
+            charset=string.punctuation + string.letters + string.digits):
+        if not self.verify_alive():
+            return
+        bytes = min(filter(None, [bytes, self.message_size_limit]))
+        messages = ["".join(random.choice(charset)
+                        for j in xrange(bytes)) + "--%s" % n
+                            for i in xrange(n)]
+        digests = []
+        chan1 = self.connection.channel()
+        consumer = Consumer(chan1, self.queue)
+        for queue in consumer.queues:
+            queue.purge()
+        producer = Producer(chan1, self.exchange)
+        for i, message in enumerate(messages):
+            producer.publish({"text": message,
+                              "i": i}, routing_key=self.prefix)
+            digests.append(self._digest(message))
+
+        received = [(msg["i"], msg["text"])
+                        for msg in consumeN(self.connection, consumer, n)]
+        self.assertEqual(len(received), n)
+        ordering = [i for i, _ in received]
+        if ordering != range(n):
+            warnings.warn(
+                "%s did not deliver messages in FIFO order: %r" % (
+                    self.transport, ordering))
+
+        for i, text in received:
+            if text != messages[i]:
+                raise AssertionError("%i: %r is not %r" % (
+                    i, text[-100:], messages[i][-100:]))
+            self.assertEqual(self._digest(text), digests[i])
+
+        chan1.close()
+        self.purge([self.queue.name])
+
     def P(self, rest):
         return "%s.%s" % (self.prefix, rest)
 
     def test_produce__consume_multiple(self):
-        if not self.transport:
+        if not self.verify_alive():
            return
-        self.verify_alive()
         chan1 = self.connection.channel()
         producer = Producer(chan1, self.exchange)
         b1 = Queue(self.P("b1"), self.exchange, "b1")(chan1)
@@ -121,9 +185,8 @@ class TransportCase(unittest.TestCase):
         self.purge([self.P("b1"), self.P("b2"), self.P("b3")])
 
     def test_timeout(self):
-        if not self.transport:
+        if not self.verify_alive():
             return
-        self.verify_alive()
         chan = self.connection.channel()
         self.purge([self.queue.name])
         consumer = Consumer(chan, self.queue)
@@ -132,9 +195,8 @@ class TransportCase(unittest.TestCase):
         consumer.cancel()
 
     def test_basic_get(self):
-        if not self.transport:
+        if not self.verify_alive():
             return
-        self.verify_alive()
         chan1 = self.connection.channel()
         producer = Producer(chan1, self.exchange)
         chan2 = self.connection.channel()
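
One line of the new large-message test that is easy to misread is the size clamp, bytes = min(filter(None, [bytes, self.message_size_limit])): unset limits (None) are filtered out and the smallest remaining value wins, so a transport that declares a message_size_limit (the beanstalk case above sets 47662) is exercised with correspondingly smaller payloads, while the others get the full 1048576-byte messages. In isolation:

# The clamp from test_produce__consume_large_messages, on its own.
bytes = 1048576                                # the test's default payload size
for message_size_limit in (None, 47662):
    print(min(filter(None, [bytes, message_size_limit])))
# prints 1048576 (no limit declared), then 47662 (the beanstalk test's limit)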

View File

@@ -62,7 +62,6 @@ class MultiChannelPoller(object):
         if client.connection._sock is None:  # not connected yet.
             client.connection.connect(client)
         sock = client.connection._sock
-        sock.setblocking(0)
         self._fd_to_chan[sock.fileno()] = (channel, type)
         self._chan_to_sock[(channel, client, type)] = sock
         self._poller.register(sock, self.eventflags)
@@ -161,7 +160,6 @@ class Channel(virtual.Channel):
         c = self.subclient
         if c.connection._sock is None:
             c.connection.connect(c)
-        c.connection._sock.setblocking(0)
         self.subclient.subscribe(keys)
         self._in_listen = True
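
With the setblocking(0) calls gone, the sockets that MultiChannelPoller registers stay in blocking mode: the poller is still what decides when a channel has data to read, but once it fires, the Redis client is free to issue as many blocking reads as it takes to assemble the whole reply. A minimal sketch of that division of labour, using generic select() code and a made-up length-prefixed frame rather than the transport's actual poller or the Redis protocol:

import select
import struct

def poll_then_read(sock, timeout=1.0):
    # select() only answers "does this socket have something to read?" ...
    readable, _, _ = select.select([sock], [], [], timeout)
    if not readable:
        return None
    # ... the reading itself uses ordinary blocking recv() calls, which wait
    # for the rest of a large frame instead of failing with EAGAIN halfway.
    header = sock.recv(4)                  # 4-byte length prefix (assumed whole)
    (size,) = struct.unpack("!I", header)
    body = b""
    while len(body) < size:
        body += sock.recv(65536)
    return body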