mirror of https://github.com/celery/kombu.git
Virtual transport deliver now calls callback, no return value. Fixes #593
This commit is contained in:
parent f71ea5c803
commit 84fc45b9a7
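The hunks below change the polling contract across the virtual transports: instead of get()/drain_events() handing a (message, queue) pair back for the caller to deliver, the caller now passes a delivery callback (normally Transport._deliver) down the polling path and nothing is returned. A minimal sketch of the two contracts, using made-up classes rather than kombu's own:

    from collections import deque

    class OldCycle(object):
        # old contract: get() returns a (message, queue) pair to the caller
        def __init__(self, items):
            self.items = deque(items)

        def get(self, timeout=None):
            return self.items.popleft()

    class NewCycle(object):
        # new contract: get(callback) delivers through the callback, returns nothing
        def __init__(self, items):
            self.items = deque(items)

        def get(self, callback, timeout=None):
            message, queue = self.items.popleft()
            callback(message, queue)  # e.g. Transport._deliver(message, queue)

    delivered = []
    NewCycle([({'body': 'hi'}, 'queue1')]).get(
        lambda message, queue: delivered.append((message, queue)))
    assert delivered == [({'body': 'hi'}, 'queue1')]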
@@ -37,7 +37,6 @@ SQS Features supported by this transport:
 
 from __future__ import absolute_import, unicode_literals
 
-import collections
 import socket
 import string
 
@@ -99,12 +98,6 @@ class Channel(virtual.Channel):
         # queues that are known to already exist.
         self._update_queue_cache(self.queue_name_prefix)
 
-        # The drain_events() method stores extra messages in a local
-        # Deque object. This allows multiple messages to be requested from
-        # SQS at once for performance, but maintains the same external API
-        # to the caller of the drain_events() method.
-        self._queue_message_cache = collections.deque()
-
         self.hub = kwargs.get('hub') or get_event_loop()
 
     def _update_queue_cache(self, queue_name_prefix):

@@ -145,24 +138,9 @@ class Channel(virtual.Channel):
         # If we're not allowed to consume or have no consumers, raise Empty
         if not self._consumers or not self.qos.can_consume():
             raise Empty()
-        message_cache = self._queue_message_cache
-
-        # Check if there are any items in our buffer. If there are any, pop
-        # off that queue first.
-        try:
-            return message_cache.popleft()
-        except IndexError:
-            pass
-
-        # At this point, go and get more messages from SQS
-        res, queue = self._poll(self.cycle, timeout=timeout)
-        message_cache.extend((r, queue) for r in res)
-
-        # Now try to pop off the queue again.
-        try:
-            return message_cache.popleft()
-        except IndexError:
-            raise Empty()
+        self._poll(self.cycle, self.connection._deliver, timeout=timeout)
 
     def _reset_cycle(self):
         """Reset the consume cycle.

@@ -286,7 +264,9 @@ class Channel(virtual.Channel):
             messages = q.get_messages(num_messages=maxcount)
 
             if messages:
-                return self._messages_to_python(messages, queue)
+                for msg in self._messages_to_python(messages, queue):
+                    self.connection._deliver(msg, queue)
+                return
         raise Empty()
 
     def _get(self, queue):
@@ -325,7 +325,7 @@ class MultiChannelPoller(object):
     def on_readable(self, fileno):
         chan, type = self._fd_to_chan[fileno]
         if chan.qos.can_consume():
-            return chan.handlers[type]()
+            chan.handlers[type]()
 
     def handle_event(self, fileno, event):
         if event & READ:

@@ -334,7 +334,7 @@ class MultiChannelPoller(object):
             chan, type = self._fd_to_chan[fileno]
             chan._poll_error(type)
 
-    def get(self, timeout=None):
+    def get(self, callback, timeout=None):
         self._in_protected_read = True
         try:
             for channel in self._channels:

@@ -345,15 +345,14 @@ class MultiChannelPoller(object):
                     self._register_LISTEN(channel)
 
             events = self.poller.poll(timeout)
-            for fileno, event in events or []:
-                ret = self.handle_event(fileno, event)
-                if ret:
-                    return ret
-
+            if events:
+                for fileno, event in events:
+                    ret = self.handle_event(fileno, event)
+                    if ret:
+                        return
             # - no new data, so try to restore messages.
             # - reset active redis commands.
             self.maybe_restore_messages()
-
             raise Empty()
         finally:
             self._in_protected_read = False

@@ -660,6 +659,16 @@ class Channel(virtual.Channel):
 
     def _receive(self):
         c = self.subclient
+        ret = []
+        try:
+            ret.append(self._receive_one(c))
+        except Empty:
+            pass
+        while c.connection.can_read(timeout=0):
+            ret.append(self._receive_one(c))
+        return any(ret)
+
+    def _receive_one(self, c):
         response = None
         try:
             response = c.parse_response()

@@ -680,8 +689,9 @@ class Channel(virtual.Channel):
                              channel, repr(payload)[:4096], exc_info=1)
                         raise Empty()
                     exchange = channel.split('/', 1)[0]
-                    return message, self._fanout_to_queue[exchange]
-        raise Empty()
+                    self.connection._deliver(
+                        message, self._fanout_to_queue[exchange])
+                    return True
 
     def _brpop_start(self, timeout=1):
         queues = self._queue_cycle.consume(len(self.active_queues))

@@ -707,7 +717,8 @@ class Channel(virtual.Channel):
                 dest, item = dest__item
                 dest = bytes_to_str(dest).rsplit(self.sep, 1)[0]
                 self._queue_cycle.rotate(dest)
-                return loads(bytes_to_str(item)), dest
+                self.connection._deliver(loads(bytes_to_str(item)), dest)
+                return True
             else:
                 raise Empty()
         finally:

@@ -1033,9 +1044,7 @@ class Transport(virtual.Transport):
 
     def on_readable(self, fileno):
         """Handle AIO event for one of our file descriptors."""
-        item = self.cycle.on_readable(fileno)
-        if item:
-            self._deliver(*item)
+        self.cycle.on_readable(fileno)
 
     def _get_errors(self):
         """Utility to import redis-py's exceptions at runtime."""
@@ -393,9 +393,13 @@ class AbstractChannel(object):
         """
         return True
 
-    def _poll(self, cycle, timeout=None):
+    def _poll(self, cycle, callback, timeout=None):
         """Poll a list of queues for available messages."""
-        return cycle.get()
+        return cycle.get(callback)
+
+    def _get_and_deliver(self, queue, callback):
+        message = self._get(queue)
+        callback(message, queue)
 
 
 class Channel(AbstractChannel, base.StdChannel):

@@ -590,6 +594,15 @@ class Channel(AbstractChannel, base.StdChannel):
 
     def basic_publish(self, message, exchange, routing_key, **kwargs):
         """Publish message."""
+        self._inplace_augment_message(message, exchange, routing_key)
+        if exchange:
+            return self.typeof(exchange).deliver(
+                message, exchange, routing_key, **kwargs
+            )
+        # anon exchange: routing_key is the destination queue
+        return self._put(routing_key, message, **kwargs)
+
+    def _inplace_augment_message(self, message, exchange, routing_key):
         message['body'], body_encoding = self.encode_body(
             message['body'], self.body_encoding,
         )

@@ -602,12 +615,6 @@ class Channel(AbstractChannel, base.StdChannel):
             exchange=exchange,
             routing_key=routing_key,
         )
-        if exchange:
-            return self.typeof(exchange).deliver(
-                message, exchange, routing_key, **kwargs
-            )
-        # anon exchange: routing_key is the destination queue
-        return self._put(routing_key, message, **kwargs)
 
     def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):
         """Consume from `queue`"""

@@ -725,11 +732,12 @@ class Channel(AbstractChannel, base.StdChannel):
     def _restore_at_beginning(self, message):
         return self._restore(message)
 
-    def drain_events(self, timeout=None):
+    def drain_events(self, timeout=None, callback=None):
+        callback = callback or self.connection._deliver
         if self._consumers and self.qos.can_consume():
             if hasattr(self, '_get_many'):
                 return self._get_many(self._active_queues, timeout=timeout)
-            return self._poll(self.cycle, timeout=timeout)
+            return self._poll(self.cycle, callback, timeout=timeout)
         raise Empty()
 
     def message_to_python(self, raw_message):

@@ -787,7 +795,8 @@ class Channel(AbstractChannel, base.StdChannel):
         return body
 
     def _reset_cycle(self):
-        self._cycle = FairCycle(self._get, self._active_queues, Empty)
+        self._cycle = FairCycle(
+            self._get_and_deliver, self._active_queues, Empty)
 
     def __enter__(self):
         return self

@@ -935,22 +944,19 @@ class Transport(base.Transport):
             channel.close()
 
     def drain_events(self, connection, timeout=None):
-        loop = 0
         time_start = monotonic()
         get = self.cycle.get
         polling_interval = self.polling_interval
         while 1:
             try:
-                item, channel = get(timeout=timeout)
+                get(self._deliver, timeout=timeout)
             except Empty:
-                if timeout and monotonic() - time_start >= timeout:
+                if timeout is not None and monotonic() - time_start >= timeout:
                     raise socket.timeout()
-                loop += 1
                 if polling_interval is not None:
                     sleep(polling_interval)
             else:
                 break
-            self._deliver(*item)
 
     def _deliver(self, message, queue):
         if not queue:

@@ -980,8 +986,8 @@ class Transport(base.Transport):
                     queue, message))
         self._callbacks[queue](message)
 
-    def _drain_channel(self, channel, timeout=None):
-        return channel.drain_events(timeout=timeout)
+    def _drain_channel(self, channel, callback, timeout=None):
+        return channel.drain_events(callback=callback, timeout=timeout)
 
     @property
     def default_connection_params(self):
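Taken together, the virtual-transport hunks above route delivery through one callback chain: Transport.drain_events passes self._deliver into cycle.get, FairCycle invokes Channel._get_and_deliver with that callback, and _get_and_deliver pulls a single message with _get and hands it straight to the callback. A rough self-contained sketch of that flow, with simplified stand-ins rather than the real kombu classes:

    class FakeChannel(object):
        # stand-in for virtual.Channel: _get pulls, _get_and_deliver pushes
        def __init__(self, queues):
            self.queues = queues  # queue name -> list of messages

        def _get(self, queue):
            return self.queues[queue].pop(0)

        def _get_and_deliver(self, queue, callback):
            message = self._get(queue)
            callback(message, queue)

    class FakeTransport(object):
        # stand-in for virtual.Transport: _deliver is the end of the chain
        def __init__(self, channel):
            self.channel = channel
            self.delivered = []

        def _deliver(self, message, queue):
            self.delivered.append((message, queue))

        def drain_events(self):
            # the real code goes through FairCycle; here we call directly
            self.channel._get_and_deliver('q', self._deliver)

    transport = FakeTransport(FakeChannel({'q': ['m1', 'm2']}))
    transport.drain_events()
    assert transport.delivered == [('m1', 'q')]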
@@ -40,15 +40,19 @@ class FairCycle(object):
         if not self.resources:
             raise self.predicate()
 
-    def get(self, **kwargs):
+    def get(self, callback, **kwargs):
+        succeeded = 0
         for tried in count(0):  # for infinity
             resource = self._next()
 
             try:
-                return self.fun(resource, **kwargs), resource
+                return self.fun(resource, callback, **kwargs)
             except self.predicate:
-                if tried >= len(self.resources) - 1:
-                    raise
+                if not succeeded:
+                    raise
+                break
+            else:
+                succeeded += 1
 
     def close(self):
         pass
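With the signature above, FairCycle.get now forwards the supplied callback to the wrapped function together with the next resource, instead of returning a (result, resource) pair. A small usage sketch (the drain helper and queue names are invented for illustration; the import path matches the tests later in this diff):

    from kombu.utils.scheduling import FairCycle

    class Empty(Exception):
        pass

    queues = {'q1': ['a'], 'q2': ['b']}

    def drain(queue, callback, timeout=None):
        # hypothetical resource function: pop one message, push it to the callback
        try:
            message = queues[queue].pop(0)
        except IndexError:
            raise Empty()
        callback(message, queue)

    received = []
    cycle = FairCycle(drain, ['q1', 'q2'], Empty)
    cycle.get(lambda message, queue: received.append((message, queue)))
    cycle.get(lambda message, queue: received.append((message, queue)))
    assert received == [('a', 'q1'), ('b', 'q2')]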
@@ -9,14 +9,14 @@ from __future__ import absolute_import, unicode_literals
 
 import pytest
 
-from case import skip
+from case import Mock, skip
 
-from kombu import five
 from kombu import messaging
 from kombu import Connection, Exchange, Queue
 
-from kombu.transport import SQS
 from kombu.async.aws.ext import exception
+from kombu.five import Empty
+from kombu.transport import SQS
 
 
 class SQSQueueMock(object):

@@ -48,7 +48,9 @@ class SQSQueueMock(object):
     def get_messages(self, num_messages=1, visibility_timeout=None,
                      attributes=None, *args, **kwargs):
         self._get_message_calls += 1
-        return self.messages[:num_messages]
+        messages, self.messages[:num_messages] = (
+            self.messages[:num_messages], [])
+        return messages
 
     def read(self, visibility_timeout=None):
         return self.messages.pop(0)

@@ -222,11 +224,11 @@ class test_Channel:
         assert len(results) == 3
 
     def test_get_with_empty_list(self):
-        with pytest.raises(five.Empty):
+        with pytest.raises(Empty):
             self.channel._get(self.queue_name)
 
     def test_get_bulk_raises_empty(self):
-        with pytest.raises(five.Empty):
+        with pytest.raises(Empty):
             self.channel._get_bulk(self.queue_name)
 
     def test_messages_to_python(self):

@@ -239,7 +241,7 @@ class test_Channel:
 
         # json formatted message NOT created by kombu
         for i in range(json_message_count):
-            message = '{"foo":"bar"}'
+            message = {'foo': 'bar'}
            self.channel._put(self.producer.routing_key, message)
 
         q = self.channel._new_queue(self.queue_name)

@@ -275,8 +277,9 @@ class test_Channel:
         # With QoS.prefetch_count = 0
         message = 'my test message'
         self.producer.publish(message)
-        results = self.channel._get_bulk(self.queue_name)
-        assert 1 == len(results)
+        self.channel.connection._deliver = Mock(name='_deliver')
+        self.channel._get_bulk(self.queue_name)
+        self.channel.connection._deliver.assert_called_once()
 
     def test_puts_and_get_bulk(self):
         # Generate 8 messages

@@ -292,57 +295,88 @@ class test_Channel:
         # Count how many messages are retrieved the first time. Should
         # be 5 (message_count).
-        results = self.channel._get_bulk(self.queue_name)
-        assert 5 == len(results)
-        for i, r in enumerate(results):
-            self.channel.qos.append(r, i)
+        self.channel.connection._deliver = Mock(name='_deliver')
+        self.channel._get_bulk(self.queue_name)
+        assert self.channel.connection._deliver.call_count == 5
+        for i in range(5):
+            self.channel.qos.append(Mock(name='message{0}'.format(i)), i)
 
         # Now, do the get again, the number of messages returned should be 1.
-        results = self.channel._get_bulk(self.queue_name)
-        assert len(results) == 1
+        self.channel.connection._deliver.reset_mock()
+        self.channel._get_bulk(self.queue_name)
+        self.channel.connection._deliver.assert_called_once()
 
     def test_drain_events_with_empty_list(self):
         def mock_can_consume():
             return False
         self.channel.qos.can_consume = mock_can_consume
-        with pytest.raises(five.Empty):
+        with pytest.raises(Empty):
             self.channel.drain_events()
 
     def test_drain_events_with_prefetch_5(self):
         # Generate 20 messages
         message_count = 20
-        expected_get_message_count = 4
+        prefetch_count = 5
 
+        current_delivery_tag = [1]
+
         # Set the prefetch_count to 5
-        self.channel.qos.prefetch_count = 5
+        self.channel.qos.prefetch_count = prefetch_count
+        self.channel.connection._deliver = Mock(name='_deliver')
+
+        def on_message_delivered(message, queue):
+            current_delivery_tag[0] += 1
+            self.channel.qos.append(message, current_delivery_tag[0])
+        self.channel.connection._deliver.side_effect = on_message_delivered
 
         # Now, generate all the messages
         for i in range(message_count):
             self.producer.publish('message: %s' % i)
 
         # Now drain all the events
-        for i in range(message_count):
-            self.channel.drain_events()
+        for i in range(1000):
+            try:
+                self.channel.drain_events(timeout=0)
+            except Empty:
+                break
+        else:
+            assert False, 'disabled infinite loop'
 
-        # How many times was the SQSConnectionMock get_message method called?
-        assert (expected_get_message_count ==
-                self.channel._queue_cache[self.queue_name]._get_message_calls)
+        self.channel.qos._flush()
+        assert len(self.channel.qos._delivered) == prefetch_count
+
+        assert self.channel.connection._deliver.call_count == prefetch_count
 
     def test_drain_events_with_prefetch_none(self):
         # Generate 20 messages
         message_count = 20
-        expected_get_message_count = 2
+        expected_get_message_count = 3
 
+        current_delivery_tag = [1]
+
         # Set the prefetch_count to None
         self.channel.qos.prefetch_count = None
+        self.channel.connection._deliver = Mock(name='_deliver')
+
+        def on_message_delivered(message, queue):
+            current_delivery_tag[0] += 1
+            self.channel.qos.append(message, current_delivery_tag[0])
+        self.channel.connection._deliver.side_effect = on_message_delivered
 
         # Now, generate all the messages
         for i in range(message_count):
             self.producer.publish('message: %s' % i)
 
         # Now drain all the events
-        for i in range(message_count):
-            self.channel.drain_events()
+        for i in range(1000):
+            try:
+                self.channel.drain_events(timeout=0)
+            except Empty:
+                break
+        else:
+            assert False, 'disabled infinite loop'
+
+        assert self.channel.connection._deliver.call_count == message_count
 
         # How many times was the SQSConnectionMock get_message method called?
         assert (expected_get_message_count ==
@@ -150,7 +150,7 @@ class test_MemoryTransport:
 
         class Cycle(object):
 
-            def get(self, timeout=None):
+            def get(self, callback, timeout=None):
                 return (message, 'foo'), c1
 
         self.c.transport.cycle = Cycle()
@@ -486,11 +486,19 @@ class test_Channel:
     def test_receive(self):
         s = self.channel.subclient = Mock()
         self.channel._fanout_to_queue['a'] = 'b'
-        s.parse_response.return_value = ['message', 'a',
-                                         dumps({'hello': 'world'})]
-        payload, queue = self.channel._receive()
-        assert payload == {'hello': 'world'}
-        assert queue == 'b'
+        self.channel.connection._deliver = Mock(name='_deliver')
+        message = {
+            'body': 'hello',
+            'properties': {
+                'delivery_tag': 1,
+                'delivery_info': {'exchange': 'E', 'routing_key': 'R'},
+            },
+        }
+        s.parse_response.return_value = ['message', 'a', dumps(message)]
+        self.channel._receive_one(self.channel.subclient)
+        self.channel.connection._deliver.assert_called_once_with(
+            message, 'b',
+        )
 
     def test_receive_raises_for_connection_error(self):
         self.channel._in_listen = True

@@ -498,22 +506,20 @@ class test_Channel:
         s.parse_response.side_effect = KeyError('foo')
 
         with pytest.raises(KeyError):
-            self.channel._receive()
+            self.channel._receive_one(self.channel.subclient)
         assert not self.channel._in_listen
 
     def test_receive_empty(self):
         s = self.channel.subclient = Mock()
         s.parse_response.return_value = None
 
-        with pytest.raises(redis.Empty):
-            self.channel._receive()
+        assert self.channel._receive_one(self.channel.subclient) is None
 
     def test_receive_different_message_Type(self):
         s = self.channel.subclient = Mock()
         s.parse_response.return_value = ['message', '/foo/', 0, 'data']
 
-        with pytest.raises(redis.Empty):
-            self.channel._receive()
+        assert self.channel._receive_one(self.channel.subclient) is None
 
     def test_brpop_read_raises(self):
         c = self.channel.client = Mock()

@@ -1152,7 +1158,7 @@ class test_MultiChannelPoller:
         p, channel = self.create_get()
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
     def test_qos_reject(self):
         p, channel = self.create_get()

@@ -1166,7 +1172,7 @@ class test_MultiChannelPoller:
         channel.qos.can_consume.return_value = True
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
         p._register_BRPOP.assert_called_with(channel)
 
@@ -1175,7 +1181,7 @@ class test_MultiChannelPoller:
         channel.qos.can_consume.return_value = False
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
         p._register_BRPOP.assert_not_called()
 
@@ -1183,7 +1189,7 @@ class test_MultiChannelPoller:
         p, channel = self.create_get(fanouts=['f_queue'])
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
         p._register_LISTEN.assert_called_with(channel)
 
@@ -1192,7 +1198,7 @@ class test_MultiChannelPoller:
         p._fd_to_chan[1] = (channel, 'BRPOP')
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
         channel._poll_error.assert_called_with('BRPOP')
 
@@ -1202,7 +1208,7 @@ class test_MultiChannelPoller:
         p._fd_to_chan[1] = (channel, 'BRPOP')
 
         with pytest.raises(redis.Empty):
-            p.get()
+            p.get(Mock())
 
         channel._poll_error.assert_called_with('BRPOP')
@@ -166,7 +166,7 @@ class test_AbstractChannel:
 
     def test_poll(self):
         cycle = Mock(name='cycle')
-        assert virtual.AbstractChannel()._poll(cycle)
+        assert virtual.AbstractChannel()._poll(cycle, Mock())
         cycle.get.assert_called()
 
 
@@ -330,6 +330,12 @@ class test_Channel:
         c.queue_bind(n, n, n)
         c.queue_declare(n + '2')
         c.queue_bind(n + '2', n, n)
+        messages = []
+        c.connection._deliver = Mock(name='_deliver')
+
+        def on_deliver(message, queue):
+            messages.append(message)
+        c.connection._deliver.side_effect = on_deliver
 
         m = c.prepare_message('nthex quick brown fox...')
         c.basic_publish(m, n, n)

@@ -344,8 +350,8 @@ class test_Channel:
         c.basic_consume(n + '2', False,
                         consumer_tag=consumer_tag, callback=lambda *a: None)
         assert n + '2' in c._active_queues
-        r2, _ = c.drain_events()
-        r2 = c.message_to_python(r2)
+        c.drain_events()
+        r2 = c.message_to_python(messages[-1])
         assert r2.body == 'nthex quick brown fox...'.encode('utf-8')
         assert r2.delivery_info['exchange'] == n
         assert r2.delivery_info['routing_key'] == n

@@ -561,7 +567,7 @@ class test_Transport:
     def test_drain_channel(self):
         channel = self.transport.create_channel(self.transport)
         with pytest.raises(virtual.Empty):
-            self.transport._drain_channel(channel)
+            self.transport._drain_channel(channel, Mock())
 
     def test__deliver__no_queue(self):
         with pytest.raises(KeyError):
@@ -2,6 +2,8 @@ from __future__ import absolute_import, unicode_literals
 
 import pytest
 
+from case import Mock
+
 from kombu.utils.scheduling import FairCycle, cycle_by_name
 
 
@@ -12,7 +14,7 @@ class MyEmpty(Exception):
 def consume(fun, n):
     r = []
     for i in range(n):
-        r.append(fun())
+        r.append(fun(Mock(name='callback')))
     return r
 
 
@@ -20,6 +22,7 @@ class test_FairCycle:
 
     def test_cycle(self):
         resources = ['a', 'b', 'c', 'd', 'e']
+        callback = Mock(name='callback')
 
         def echo(r, timeout=None):
             return r

@@ -27,26 +30,24 @@ class test_FairCycle:
         # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat]
         cycle = FairCycle(echo, resources, MyEmpty)
         for i in range(len(resources)):
-            assert cycle.get() == (resources[i], resources[i])
+            assert cycle.get(callback) == resources[i]
         for i in range(len(resources)):
-            assert cycle.get() == (resources[i], resources[i])
+            assert cycle.get(callback) == resources[i]
 
     def test_cycle_breaks(self):
         resources = ['a', 'b', 'c', 'd', 'e']
 
-        def echo(r):
+        def echo(r, callback):
             if r == 'c':
                 raise MyEmpty(r)
             return r
 
         cycle = FairCycle(echo, resources, MyEmpty)
         assert consume(cycle.get, len(resources)) == [
-            ('a', 'a'), ('b', 'b'), ('d', 'd'),
-            ('e', 'e'), ('a', 'a'),
+            'a', 'b', 'd', 'e', 'a',
         ]
         assert consume(cycle.get, len(resources)) == [
-            ('b', 'b'), ('d', 'd'), ('e', 'e'),
-            ('a', 'a'), ('b', 'b'),
+            'b', 'd', 'e', 'a', 'b',
         ]
         cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty)
         with pytest.raises(MyEmpty):