Virtual transport deliver now calls callback, no return value. Fixes #593

This commit is contained in:
Ask Solem 2016-09-02 15:35:39 -07:00
parent f71ea5c803
commit 84fc45b9a7
9 changed files with 161 additions and 115 deletions

View File

@ -37,7 +37,6 @@ SQS Features supported by this transport:
from __future__ import absolute_import, unicode_literals
import collections
import socket
import string
@ -99,12 +98,6 @@ class Channel(virtual.Channel):
# queues that are known to already exist.
self._update_queue_cache(self.queue_name_prefix)
# The drain_events() method stores extra messages in a local
# Deque object. This allows multiple messages to be requested from
# SQS at once for performance, but maintains the same external API
# to the caller of the drain_events() method.
self._queue_message_cache = collections.deque()
self.hub = kwargs.get('hub') or get_event_loop()
def _update_queue_cache(self, queue_name_prefix):
@ -145,24 +138,9 @@ class Channel(virtual.Channel):
# If we're not allowed to consume or have no consumers, raise Empty
if not self._consumers or not self.qos.can_consume():
raise Empty()
message_cache = self._queue_message_cache
# Check if there are any items in our buffer. If there are any, pop
# off that queue first.
try:
return message_cache.popleft()
except IndexError:
pass
# At this point, go and get more messages from SQS
res, queue = self._poll(self.cycle, timeout=timeout)
message_cache.extend((r, queue) for r in res)
# Now try to pop off the queue again.
try:
return message_cache.popleft()
except IndexError:
raise Empty()
self._poll(self.cycle, self.connection._deliver, timeout=timeout)
def _reset_cycle(self):
"""Reset the consume cycle.
@ -286,7 +264,9 @@ class Channel(virtual.Channel):
messages = q.get_messages(num_messages=maxcount)
if messages:
return self._messages_to_python(messages, queue)
for msg in self._messages_to_python(messages, queue):
self.connection._deliver(msg, queue)
return
raise Empty()
def _get(self, queue):

View File

@ -325,7 +325,7 @@ class MultiChannelPoller(object):
def on_readable(self, fileno):
chan, type = self._fd_to_chan[fileno]
if chan.qos.can_consume():
return chan.handlers[type]()
chan.handlers[type]()
def handle_event(self, fileno, event):
if event & READ:
@ -334,7 +334,7 @@ class MultiChannelPoller(object):
chan, type = self._fd_to_chan[fileno]
chan._poll_error(type)
def get(self, timeout=None):
def get(self, callback, timeout=None):
self._in_protected_read = True
try:
for channel in self._channels:
@ -345,15 +345,14 @@ class MultiChannelPoller(object):
self._register_LISTEN(channel)
events = self.poller.poll(timeout)
for fileno, event in events or []:
ret = self.handle_event(fileno, event)
if ret:
return ret
if events:
for fileno, event in events:
ret = self.handle_event(fileno, event)
if ret:
return
# - no new data, so try to restore messages.
# - reset active redis commands.
self.maybe_restore_messages()
raise Empty()
finally:
self._in_protected_read = False
@ -660,6 +659,16 @@ class Channel(virtual.Channel):
def _receive(self):
c = self.subclient
ret = []
try:
ret.append(self._receive_one(c))
except Empty:
pass
while c.connection.can_read(timeout=0):
ret.append(self._receive_one(c))
return any(ret)
def _receive_one(self, c):
response = None
try:
response = c.parse_response()
@ -680,8 +689,9 @@ class Channel(virtual.Channel):
channel, repr(payload)[:4096], exc_info=1)
raise Empty()
exchange = channel.split('/', 1)[0]
return message, self._fanout_to_queue[exchange]
raise Empty()
self.connection._deliver(
message, self._fanout_to_queue[exchange])
return True
def _brpop_start(self, timeout=1):
queues = self._queue_cycle.consume(len(self.active_queues))
@ -707,7 +717,8 @@ class Channel(virtual.Channel):
dest, item = dest__item
dest = bytes_to_str(dest).rsplit(self.sep, 1)[0]
self._queue_cycle.rotate(dest)
return loads(bytes_to_str(item)), dest
self.connection._deliver(loads(bytes_to_str(item)), dest)
return True
else:
raise Empty()
finally:
@ -1033,9 +1044,7 @@ class Transport(virtual.Transport):
def on_readable(self, fileno):
"""Handle AIO event for one of our file descriptors."""
item = self.cycle.on_readable(fileno)
if item:
self._deliver(*item)
self.cycle.on_readable(fileno)
def _get_errors(self):
"""Utility to import redis-py's exceptions at runtime."""

View File

@ -393,9 +393,13 @@ class AbstractChannel(object):
"""
return True
def _poll(self, cycle, timeout=None):
def _poll(self, cycle, callback, timeout=None):
"""Poll a list of queues for available messages."""
return cycle.get()
return cycle.get(callback)
def _get_and_deliver(self, queue, callback):
message = self._get(queue)
callback(message, queue)
class Channel(AbstractChannel, base.StdChannel):
@ -590,6 +594,15 @@ class Channel(AbstractChannel, base.StdChannel):
def basic_publish(self, message, exchange, routing_key, **kwargs):
"""Publish message."""
self._inplace_augment_message(message, exchange, routing_key)
if exchange:
return self.typeof(exchange).deliver(
message, exchange, routing_key, **kwargs
)
# anon exchange: routing_key is the destination queue
return self._put(routing_key, message, **kwargs)
def _inplace_augment_message(self, message, exchange, routing_key):
message['body'], body_encoding = self.encode_body(
message['body'], self.body_encoding,
)
@ -602,12 +615,6 @@ class Channel(AbstractChannel, base.StdChannel):
exchange=exchange,
routing_key=routing_key,
)
if exchange:
return self.typeof(exchange).deliver(
message, exchange, routing_key, **kwargs
)
# anon exchange: routing_key is the destination queue
return self._put(routing_key, message, **kwargs)
def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):
"""Consume from `queue`"""
@ -725,11 +732,12 @@ class Channel(AbstractChannel, base.StdChannel):
def _restore_at_beginning(self, message):
return self._restore(message)
def drain_events(self, timeout=None):
def drain_events(self, timeout=None, callback=None):
callback = callback or self.connection._deliver
if self._consumers and self.qos.can_consume():
if hasattr(self, '_get_many'):
return self._get_many(self._active_queues, timeout=timeout)
return self._poll(self.cycle, timeout=timeout)
return self._poll(self.cycle, callback, timeout=timeout)
raise Empty()
def message_to_python(self, raw_message):
@ -787,7 +795,8 @@ class Channel(AbstractChannel, base.StdChannel):
return body
def _reset_cycle(self):
self._cycle = FairCycle(self._get, self._active_queues, Empty)
self._cycle = FairCycle(
self._get_and_deliver, self._active_queues, Empty)
def __enter__(self):
return self
@ -935,22 +944,19 @@ class Transport(base.Transport):
channel.close()
def drain_events(self, connection, timeout=None):
loop = 0
time_start = monotonic()
get = self.cycle.get
polling_interval = self.polling_interval
while 1:
try:
item, channel = get(timeout=timeout)
get(self._deliver, timeout=timeout)
except Empty:
if timeout and monotonic() - time_start >= timeout:
if timeout is not None and monotonic() - time_start >= timeout:
raise socket.timeout()
loop += 1
if polling_interval is not None:
sleep(polling_interval)
else:
break
self._deliver(*item)
def _deliver(self, message, queue):
if not queue:
@ -980,8 +986,8 @@ class Transport(base.Transport):
queue, message))
self._callbacks[queue](message)
def _drain_channel(self, channel, timeout=None):
return channel.drain_events(timeout=timeout)
def _drain_channel(self, channel, callback, timeout=None):
return channel.drain_events(callback=callback, timeout=timeout)
@property
def default_connection_params(self):

View File

@ -40,15 +40,19 @@ class FairCycle(object):
if not self.resources:
raise self.predicate()
def get(self, **kwargs):
def get(self, callback, **kwargs):
succeeded = 0
for tried in count(0): # for infinity
resource = self._next()
try:
return self.fun(resource, **kwargs), resource
return self.fun(resource, callback, **kwargs)
except self.predicate:
if tried >= len(self.resources) - 1:
raise
if not succeeded:
raise
break
else:
succeeded += 1
def close(self):
pass

View File

@ -9,14 +9,14 @@ from __future__ import absolute_import, unicode_literals
import pytest
from case import skip
from case import Mock, skip
from kombu import five
from kombu import messaging
from kombu import Connection, Exchange, Queue
from kombu.transport import SQS
from kombu.async.aws.ext import exception
from kombu.five import Empty
from kombu.transport import SQS
class SQSQueueMock(object):
@ -48,7 +48,9 @@ class SQSQueueMock(object):
def get_messages(self, num_messages=1, visibility_timeout=None,
attributes=None, *args, **kwargs):
self._get_message_calls += 1
return self.messages[:num_messages]
messages, self.messages[:num_messages] = (
self.messages[:num_messages], [])
return messages
def read(self, visibility_timeout=None):
return self.messages.pop(0)
@ -222,11 +224,11 @@ class test_Channel:
assert len(results) == 3
def test_get_with_empty_list(self):
with pytest.raises(five.Empty):
with pytest.raises(Empty):
self.channel._get(self.queue_name)
def test_get_bulk_raises_empty(self):
with pytest.raises(five.Empty):
with pytest.raises(Empty):
self.channel._get_bulk(self.queue_name)
def test_messages_to_python(self):
@ -239,7 +241,7 @@ class test_Channel:
# json formatted message NOT created by kombu
for i in range(json_message_count):
message = '{"foo":"bar"}'
message = {'foo': 'bar'}
self.channel._put(self.producer.routing_key, message)
q = self.channel._new_queue(self.queue_name)
@ -275,8 +277,9 @@ class test_Channel:
# With QoS.prefetch_count = 0
message = 'my test message'
self.producer.publish(message)
results = self.channel._get_bulk(self.queue_name)
assert 1 == len(results)
self.channel.connection._deliver = Mock(name='_deliver')
self.channel._get_bulk(self.queue_name)
self.channel.connection._deliver.assert_called_once()
def test_puts_and_get_bulk(self):
# Generate 8 messages
@ -292,57 +295,88 @@ class test_Channel:
# Count how many messages are retrieved the first time. Should
# be 5 (message_count).
results = self.channel._get_bulk(self.queue_name)
assert 5 == len(results)
for i, r in enumerate(results):
self.channel.qos.append(r, i)
self.channel.connection._deliver = Mock(name='_deliver')
self.channel._get_bulk(self.queue_name)
assert self.channel.connection._deliver.call_count == 5
for i in range(5):
self.channel.qos.append(Mock(name='message{0}'.format(i)), i)
# Now, do the get again, the number of messages returned should be 1.
results = self.channel._get_bulk(self.queue_name)
assert len(results) == 1
self.channel.connection._deliver.reset_mock()
self.channel._get_bulk(self.queue_name)
self.channel.connection._deliver.assert_called_once()
def test_drain_events_with_empty_list(self):
def mock_can_consume():
return False
self.channel.qos.can_consume = mock_can_consume
with pytest.raises(five.Empty):
with pytest.raises(Empty):
self.channel.drain_events()
def test_drain_events_with_prefetch_5(self):
# Generate 20 messages
message_count = 20
expected_get_message_count = 4
prefetch_count = 5
current_delivery_tag = [1]
# Set the prefetch_count to 5
self.channel.qos.prefetch_count = 5
self.channel.qos.prefetch_count = prefetch_count
self.channel.connection._deliver = Mock(name='_deliver')
def on_message_delivered(message, queue):
current_delivery_tag[0] += 1
self.channel.qos.append(message, current_delivery_tag[0])
self.channel.connection._deliver.side_effect = on_message_delivered
# Now, generate all the messages
for i in range(message_count):
self.producer.publish('message: %s' % i)
# Now drain all the events
for i in range(message_count):
self.channel.drain_events()
for i in range(1000):
try:
self.channel.drain_events(timeout=0)
except Empty:
break
else:
assert False, 'disabled infinite loop'
# How many times was the SQSConnectionMock get_message method called?
assert (expected_get_message_count ==
self.channel._queue_cache[self.queue_name]._get_message_calls)
self.channel.qos._flush()
assert len(self.channel.qos._delivered) == prefetch_count
assert self.channel.connection._deliver.call_count == prefetch_count
def test_drain_events_with_prefetch_none(self):
# Generate 20 messages
message_count = 20
expected_get_message_count = 2
expected_get_message_count = 3
current_delivery_tag = [1]
# Set the prefetch_count to None
self.channel.qos.prefetch_count = None
self.channel.connection._deliver = Mock(name='_deliver')
def on_message_delivered(message, queue):
current_delivery_tag[0] += 1
self.channel.qos.append(message, current_delivery_tag[0])
self.channel.connection._deliver.side_effect = on_message_delivered
# Now, generate all the messages
for i in range(message_count):
self.producer.publish('message: %s' % i)
# Now drain all the events
for i in range(message_count):
self.channel.drain_events()
for i in range(1000):
try:
self.channel.drain_events(timeout=0)
except Empty:
break
else:
assert False, 'disabled infinite loop'
assert self.channel.connection._deliver.call_count == message_count
# How many times was the SQSConnectionMock get_message method called?
assert (expected_get_message_count ==

View File

@ -150,7 +150,7 @@ class test_MemoryTransport:
class Cycle(object):
def get(self, timeout=None):
def get(self, callback, timeout=None):
return (message, 'foo'), c1
self.c.transport.cycle = Cycle()

View File

@ -486,11 +486,19 @@ class test_Channel:
def test_receive(self):
s = self.channel.subclient = Mock()
self.channel._fanout_to_queue['a'] = 'b'
s.parse_response.return_value = ['message', 'a',
dumps({'hello': 'world'})]
payload, queue = self.channel._receive()
assert payload == {'hello': 'world'}
assert queue == 'b'
self.channel.connection._deliver = Mock(name='_deliver')
message = {
'body': 'hello',
'properties': {
'delivery_tag': 1,
'delivery_info': {'exchange': 'E', 'routing_key': 'R'},
},
}
s.parse_response.return_value = ['message', 'a', dumps(message)]
self.channel._receive_one(self.channel.subclient)
self.channel.connection._deliver.assert_called_once_with(
message, 'b',
)
def test_receive_raises_for_connection_error(self):
self.channel._in_listen = True
@ -498,22 +506,20 @@ class test_Channel:
s.parse_response.side_effect = KeyError('foo')
with pytest.raises(KeyError):
self.channel._receive()
self.channel._receive_one(self.channel.subclient)
assert not self.channel._in_listen
def test_receive_empty(self):
s = self.channel.subclient = Mock()
s.parse_response.return_value = None
with pytest.raises(redis.Empty):
self.channel._receive()
assert self.channel._receive_one(self.channel.subclient) is None
def test_receive_different_message_Type(self):
s = self.channel.subclient = Mock()
s.parse_response.return_value = ['message', '/foo/', 0, 'data']
with pytest.raises(redis.Empty):
self.channel._receive()
assert self.channel._receive_one(self.channel.subclient) is None
def test_brpop_read_raises(self):
c = self.channel.client = Mock()
@ -1152,7 +1158,7 @@ class test_MultiChannelPoller:
p, channel = self.create_get()
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
def test_qos_reject(self):
p, channel = self.create_get()
@ -1166,7 +1172,7 @@ class test_MultiChannelPoller:
channel.qos.can_consume.return_value = True
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
p._register_BRPOP.assert_called_with(channel)
@ -1175,7 +1181,7 @@ class test_MultiChannelPoller:
channel.qos.can_consume.return_value = False
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
p._register_BRPOP.assert_not_called()
@ -1183,7 +1189,7 @@ class test_MultiChannelPoller:
p, channel = self.create_get(fanouts=['f_queue'])
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
p._register_LISTEN.assert_called_with(channel)
@ -1192,7 +1198,7 @@ class test_MultiChannelPoller:
p._fd_to_chan[1] = (channel, 'BRPOP')
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
channel._poll_error.assert_called_with('BRPOP')
@ -1202,7 +1208,7 @@ class test_MultiChannelPoller:
p._fd_to_chan[1] = (channel, 'BRPOP')
with pytest.raises(redis.Empty):
p.get()
p.get(Mock())
channel._poll_error.assert_called_with('BRPOP')

View File

@ -166,7 +166,7 @@ class test_AbstractChannel:
def test_poll(self):
cycle = Mock(name='cycle')
assert virtual.AbstractChannel()._poll(cycle)
assert virtual.AbstractChannel()._poll(cycle, Mock())
cycle.get.assert_called()
@ -330,6 +330,12 @@ class test_Channel:
c.queue_bind(n, n, n)
c.queue_declare(n + '2')
c.queue_bind(n + '2', n, n)
messages = []
c.connection._deliver = Mock(name='_deliver')
def on_deliver(message, queue):
messages.append(message)
c.connection._deliver.side_effect = on_deliver
m = c.prepare_message('nthex quick brown fox...')
c.basic_publish(m, n, n)
@ -344,8 +350,8 @@ class test_Channel:
c.basic_consume(n + '2', False,
consumer_tag=consumer_tag, callback=lambda *a: None)
assert n + '2' in c._active_queues
r2, _ = c.drain_events()
r2 = c.message_to_python(r2)
c.drain_events()
r2 = c.message_to_python(messages[-1])
assert r2.body == 'nthex quick brown fox...'.encode('utf-8')
assert r2.delivery_info['exchange'] == n
assert r2.delivery_info['routing_key'] == n
@ -561,7 +567,7 @@ class test_Transport:
def test_drain_channel(self):
channel = self.transport.create_channel(self.transport)
with pytest.raises(virtual.Empty):
self.transport._drain_channel(channel)
self.transport._drain_channel(channel, Mock())
def test__deliver__no_queue(self):
with pytest.raises(KeyError):

View File

@ -2,6 +2,8 @@ from __future__ import absolute_import, unicode_literals
import pytest
from case import Mock
from kombu.utils.scheduling import FairCycle, cycle_by_name
@ -12,7 +14,7 @@ class MyEmpty(Exception):
def consume(fun, n):
r = []
for i in range(n):
r.append(fun())
r.append(fun(Mock(name='callback')))
return r
@ -20,6 +22,7 @@ class test_FairCycle:
def test_cycle(self):
resources = ['a', 'b', 'c', 'd', 'e']
callback = Mock(name='callback')
def echo(r, timeout=None):
return r
@ -27,26 +30,24 @@ class test_FairCycle:
# cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat]
cycle = FairCycle(echo, resources, MyEmpty)
for i in range(len(resources)):
assert cycle.get() == (resources[i], resources[i])
assert cycle.get(callback) == resources[i]
for i in range(len(resources)):
assert cycle.get() == (resources[i], resources[i])
assert cycle.get(callback) == resources[i]
def test_cycle_breaks(self):
resources = ['a', 'b', 'c', 'd', 'e']
def echo(r):
def echo(r, callback):
if r == 'c':
raise MyEmpty(r)
return r
cycle = FairCycle(echo, resources, MyEmpty)
assert consume(cycle.get, len(resources)) == [
('a', 'a'), ('b', 'b'), ('d', 'd'),
('e', 'e'), ('a', 'a'),
'a', 'b', 'd', 'e', 'a',
]
assert consume(cycle.get, len(resources)) == [
('b', 'b'), ('d', 'd'), ('e', 'e'),
('a', 'a'), ('b', 'b'),
'b', 'd', 'e', 'a', 'b',
]
cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty)
with pytest.raises(MyEmpty):