From 84fc45b9a76026c054ab1d31ec3883dd2e56c440 Mon Sep 17 00:00:00 2001 From: Ask Solem Date: Fri, 2 Sep 2016 15:35:39 -0700 Subject: [PATCH] Virtual transport deliver now calls callback, no return value. Fixes #593 --- kombu/transport/SQS.py | 28 ++------- kombu/transport/redis.py | 37 +++++++----- kombu/transport/virtual/base.py | 42 +++++++------ kombu/utils/scheduling.py | 12 ++-- t/unit/transport/test_SQS.py | 86 +++++++++++++++++++-------- t/unit/transport/test_memory.py | 2 +- t/unit/transport/test_redis.py | 38 +++++++----- t/unit/transport/virtual/test_base.py | 14 +++-- t/unit/utils/test_scheduling.py | 17 +++--- 9 files changed, 161 insertions(+), 115 deletions(-) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index eb261cea..0f1c65fb 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -37,7 +37,6 @@ SQS Features supported by this transport: from __future__ import absolute_import, unicode_literals -import collections import socket import string @@ -99,12 +98,6 @@ class Channel(virtual.Channel): # queues that are known to already exist. self._update_queue_cache(self.queue_name_prefix) - # The drain_events() method stores extra messages in a local - # Deque object. This allows multiple messages to be requested from - # SQS at once for performance, but maintains the same external API - # to the caller of the drain_events() method. - self._queue_message_cache = collections.deque() - self.hub = kwargs.get('hub') or get_event_loop() def _update_queue_cache(self, queue_name_prefix): @@ -145,24 +138,9 @@ class Channel(virtual.Channel): # If we're not allowed to consume or have no consumers, raise Empty if not self._consumers or not self.qos.can_consume(): raise Empty() - message_cache = self._queue_message_cache - - # Check if there are any items in our buffer. If there are any, pop - # off that queue first. - try: - return message_cache.popleft() - except IndexError: - pass # At this point, go and get more messages from SQS - res, queue = self._poll(self.cycle, timeout=timeout) - message_cache.extend((r, queue) for r in res) - - # Now try to pop off the queue again. - try: - return message_cache.popleft() - except IndexError: - raise Empty() + self._poll(self.cycle, self.connection._deliver, timeout=timeout) def _reset_cycle(self): """Reset the consume cycle. @@ -286,7 +264,9 @@ class Channel(virtual.Channel): messages = q.get_messages(num_messages=maxcount) if messages: - return self._messages_to_python(messages, queue) + for msg in self._messages_to_python(messages, queue): + self.connection._deliver(msg, queue) + return raise Empty() def _get(self, queue): diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index d7546a60..0b4272e1 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -325,7 +325,7 @@ class MultiChannelPoller(object): def on_readable(self, fileno): chan, type = self._fd_to_chan[fileno] if chan.qos.can_consume(): - return chan.handlers[type]() + chan.handlers[type]() def handle_event(self, fileno, event): if event & READ: @@ -334,7 +334,7 @@ class MultiChannelPoller(object): chan, type = self._fd_to_chan[fileno] chan._poll_error(type) - def get(self, timeout=None): + def get(self, callback, timeout=None): self._in_protected_read = True try: for channel in self._channels: @@ -345,15 +345,14 @@ class MultiChannelPoller(object): self._register_LISTEN(channel) events = self.poller.poll(timeout) - for fileno, event in events or []: - ret = self.handle_event(fileno, event) - if ret: - return ret - + if events: + for fileno, event in events: + ret = self.handle_event(fileno, event) + if ret: + return # - no new data, so try to restore messages. # - reset active redis commands. self.maybe_restore_messages() - raise Empty() finally: self._in_protected_read = False @@ -660,6 +659,16 @@ class Channel(virtual.Channel): def _receive(self): c = self.subclient + ret = [] + try: + ret.append(self._receive_one(c)) + except Empty: + pass + while c.connection.can_read(timeout=0): + ret.append(self._receive_one(c)) + return any(ret) + + def _receive_one(self, c): response = None try: response = c.parse_response() @@ -680,8 +689,9 @@ class Channel(virtual.Channel): channel, repr(payload)[:4096], exc_info=1) raise Empty() exchange = channel.split('/', 1)[0] - return message, self._fanout_to_queue[exchange] - raise Empty() + self.connection._deliver( + message, self._fanout_to_queue[exchange]) + return True def _brpop_start(self, timeout=1): queues = self._queue_cycle.consume(len(self.active_queues)) @@ -707,7 +717,8 @@ class Channel(virtual.Channel): dest, item = dest__item dest = bytes_to_str(dest).rsplit(self.sep, 1)[0] self._queue_cycle.rotate(dest) - return loads(bytes_to_str(item)), dest + self.connection._deliver(loads(bytes_to_str(item)), dest) + return True else: raise Empty() finally: @@ -1033,9 +1044,7 @@ class Transport(virtual.Transport): def on_readable(self, fileno): """Handle AIO event for one of our file descriptors.""" - item = self.cycle.on_readable(fileno) - if item: - self._deliver(*item) + self.cycle.on_readable(fileno) def _get_errors(self): """Utility to import redis-py's exceptions at runtime.""" diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index ccfab350..372a32e9 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -393,9 +393,13 @@ class AbstractChannel(object): """ return True - def _poll(self, cycle, timeout=None): + def _poll(self, cycle, callback, timeout=None): """Poll a list of queues for available messages.""" - return cycle.get() + return cycle.get(callback) + + def _get_and_deliver(self, queue, callback): + message = self._get(queue) + callback(message, queue) class Channel(AbstractChannel, base.StdChannel): @@ -590,6 +594,15 @@ class Channel(AbstractChannel, base.StdChannel): def basic_publish(self, message, exchange, routing_key, **kwargs): """Publish message.""" + self._inplace_augment_message(message, exchange, routing_key) + if exchange: + return self.typeof(exchange).deliver( + message, exchange, routing_key, **kwargs + ) + # anon exchange: routing_key is the destination queue + return self._put(routing_key, message, **kwargs) + + def _inplace_augment_message(self, message, exchange, routing_key): message['body'], body_encoding = self.encode_body( message['body'], self.body_encoding, ) @@ -602,12 +615,6 @@ class Channel(AbstractChannel, base.StdChannel): exchange=exchange, routing_key=routing_key, ) - if exchange: - return self.typeof(exchange).deliver( - message, exchange, routing_key, **kwargs - ) - # anon exchange: routing_key is the destination queue - return self._put(routing_key, message, **kwargs) def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): """Consume from `queue`""" @@ -725,11 +732,12 @@ class Channel(AbstractChannel, base.StdChannel): def _restore_at_beginning(self, message): return self._restore(message) - def drain_events(self, timeout=None): + def drain_events(self, timeout=None, callback=None): + callback = callback or self.connection._deliver if self._consumers and self.qos.can_consume(): if hasattr(self, '_get_many'): return self._get_many(self._active_queues, timeout=timeout) - return self._poll(self.cycle, timeout=timeout) + return self._poll(self.cycle, callback, timeout=timeout) raise Empty() def message_to_python(self, raw_message): @@ -787,7 +795,8 @@ class Channel(AbstractChannel, base.StdChannel): return body def _reset_cycle(self): - self._cycle = FairCycle(self._get, self._active_queues, Empty) + self._cycle = FairCycle( + self._get_and_deliver, self._active_queues, Empty) def __enter__(self): return self @@ -935,22 +944,19 @@ class Transport(base.Transport): channel.close() def drain_events(self, connection, timeout=None): - loop = 0 time_start = monotonic() get = self.cycle.get polling_interval = self.polling_interval while 1: try: - item, channel = get(timeout=timeout) + get(self._deliver, timeout=timeout) except Empty: - if timeout and monotonic() - time_start >= timeout: + if timeout is not None and monotonic() - time_start >= timeout: raise socket.timeout() - loop += 1 if polling_interval is not None: sleep(polling_interval) else: break - self._deliver(*item) def _deliver(self, message, queue): if not queue: @@ -980,8 +986,8 @@ class Transport(base.Transport): queue, message)) self._callbacks[queue](message) - def _drain_channel(self, channel, timeout=None): - return channel.drain_events(timeout=timeout) + def _drain_channel(self, channel, callback, timeout=None): + return channel.drain_events(callback=callback, timeout=timeout) @property def default_connection_params(self): diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py index f44222db..e5833fa1 100644 --- a/kombu/utils/scheduling.py +++ b/kombu/utils/scheduling.py @@ -40,15 +40,19 @@ class FairCycle(object): if not self.resources: raise self.predicate() - def get(self, **kwargs): + def get(self, callback, **kwargs): + succeeded = 0 for tried in count(0): # for infinity resource = self._next() - try: - return self.fun(resource, **kwargs), resource + return self.fun(resource, callback, **kwargs) except self.predicate: if tried >= len(self.resources) - 1: - raise + if not succeeded: + raise + break + else: + succeeded += 1 def close(self): pass diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 5032d43c..12d21a49 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -9,14 +9,14 @@ from __future__ import absolute_import, unicode_literals import pytest -from case import skip +from case import Mock, skip -from kombu import five from kombu import messaging from kombu import Connection, Exchange, Queue -from kombu.transport import SQS from kombu.async.aws.ext import exception +from kombu.five import Empty +from kombu.transport import SQS class SQSQueueMock(object): @@ -48,7 +48,9 @@ class SQSQueueMock(object): def get_messages(self, num_messages=1, visibility_timeout=None, attributes=None, *args, **kwargs): self._get_message_calls += 1 - return self.messages[:num_messages] + messages, self.messages[:num_messages] = ( + self.messages[:num_messages], []) + return messages def read(self, visibility_timeout=None): return self.messages.pop(0) @@ -222,11 +224,11 @@ class test_Channel: assert len(results) == 3 def test_get_with_empty_list(self): - with pytest.raises(five.Empty): + with pytest.raises(Empty): self.channel._get(self.queue_name) def test_get_bulk_raises_empty(self): - with pytest.raises(five.Empty): + with pytest.raises(Empty): self.channel._get_bulk(self.queue_name) def test_messages_to_python(self): @@ -239,7 +241,7 @@ class test_Channel: # json formatted message NOT created by kombu for i in range(json_message_count): - message = '{"foo":"bar"}' + message = {'foo': 'bar'} self.channel._put(self.producer.routing_key, message) q = self.channel._new_queue(self.queue_name) @@ -275,8 +277,9 @@ class test_Channel: # With QoS.prefetch_count = 0 message = 'my test message' self.producer.publish(message) - results = self.channel._get_bulk(self.queue_name) - assert 1 == len(results) + self.channel.connection._deliver = Mock(name='_deliver') + self.channel._get_bulk(self.queue_name) + self.channel.connection._deliver.assert_called_once() def test_puts_and_get_bulk(self): # Generate 8 messages @@ -292,57 +295,88 @@ class test_Channel: # Count how many messages are retrieved the first time. Should # be 5 (message_count). - results = self.channel._get_bulk(self.queue_name) - assert 5 == len(results) - for i, r in enumerate(results): - self.channel.qos.append(r, i) + self.channel.connection._deliver = Mock(name='_deliver') + self.channel._get_bulk(self.queue_name) + assert self.channel.connection._deliver.call_count == 5 + for i in range(5): + self.channel.qos.append(Mock(name='message{0}'.format(i)), i) # Now, do the get again, the number of messages returned should be 1. - results = self.channel._get_bulk(self.queue_name) - assert len(results) == 1 + self.channel.connection._deliver.reset_mock() + self.channel._get_bulk(self.queue_name) + self.channel.connection._deliver.assert_called_once() def test_drain_events_with_empty_list(self): def mock_can_consume(): return False self.channel.qos.can_consume = mock_can_consume - with pytest.raises(five.Empty): + with pytest.raises(Empty): self.channel.drain_events() def test_drain_events_with_prefetch_5(self): # Generate 20 messages message_count = 20 - expected_get_message_count = 4 + prefetch_count = 5 + + current_delivery_tag = [1] # Set the prefetch_count to 5 - self.channel.qos.prefetch_count = 5 + self.channel.qos.prefetch_count = prefetch_count + self.channel.connection._deliver = Mock(name='_deliver') + + def on_message_delivered(message, queue): + current_delivery_tag[0] += 1 + self.channel.qos.append(message, current_delivery_tag[0]) + self.channel.connection._deliver.side_effect = on_message_delivered # Now, generate all the messages for i in range(message_count): self.producer.publish('message: %s' % i) # Now drain all the events - for i in range(message_count): - self.channel.drain_events() + for i in range(1000): + try: + self.channel.drain_events(timeout=0) + except Empty: + break + else: + assert False, 'disabled infinite loop' - # How many times was the SQSConnectionMock get_message method called? - assert (expected_get_message_count == - self.channel._queue_cache[self.queue_name]._get_message_calls) + self.channel.qos._flush() + assert len(self.channel.qos._delivered) == prefetch_count + + assert self.channel.connection._deliver.call_count == prefetch_count def test_drain_events_with_prefetch_none(self): # Generate 20 messages message_count = 20 - expected_get_message_count = 2 + expected_get_message_count = 3 + + current_delivery_tag = [1] # Set the prefetch_count to None self.channel.qos.prefetch_count = None + self.channel.connection._deliver = Mock(name='_deliver') + + def on_message_delivered(message, queue): + current_delivery_tag[0] += 1 + self.channel.qos.append(message, current_delivery_tag[0]) + self.channel.connection._deliver.side_effect = on_message_delivered # Now, generate all the messages for i in range(message_count): self.producer.publish('message: %s' % i) # Now drain all the events - for i in range(message_count): - self.channel.drain_events() + for i in range(1000): + try: + self.channel.drain_events(timeout=0) + except Empty: + break + else: + assert False, 'disabled infinite loop' + + assert self.channel.connection._deliver.call_count == message_count # How many times was the SQSConnectionMock get_message method called? assert (expected_get_message_count == diff --git a/t/unit/transport/test_memory.py b/t/unit/transport/test_memory.py index d4dc3390..5d3bf6fc 100644 --- a/t/unit/transport/test_memory.py +++ b/t/unit/transport/test_memory.py @@ -150,7 +150,7 @@ class test_MemoryTransport: class Cycle(object): - def get(self, timeout=None): + def get(self, callback, timeout=None): return (message, 'foo'), c1 self.c.transport.cycle = Cycle() diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index fd47317f..64bf5587 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -486,11 +486,19 @@ class test_Channel: def test_receive(self): s = self.channel.subclient = Mock() self.channel._fanout_to_queue['a'] = 'b' - s.parse_response.return_value = ['message', 'a', - dumps({'hello': 'world'})] - payload, queue = self.channel._receive() - assert payload == {'hello': 'world'} - assert queue == 'b' + self.channel.connection._deliver = Mock(name='_deliver') + message = { + 'body': 'hello', + 'properties': { + 'delivery_tag': 1, + 'delivery_info': {'exchange': 'E', 'routing_key': 'R'}, + }, + } + s.parse_response.return_value = ['message', 'a', dumps(message)] + self.channel._receive_one(self.channel.subclient) + self.channel.connection._deliver.assert_called_once_with( + message, 'b', + ) def test_receive_raises_for_connection_error(self): self.channel._in_listen = True @@ -498,22 +506,20 @@ class test_Channel: s.parse_response.side_effect = KeyError('foo') with pytest.raises(KeyError): - self.channel._receive() + self.channel._receive_one(self.channel.subclient) assert not self.channel._in_listen def test_receive_empty(self): s = self.channel.subclient = Mock() s.parse_response.return_value = None - with pytest.raises(redis.Empty): - self.channel._receive() + assert self.channel._receive_one(self.channel.subclient) is None def test_receive_different_message_Type(self): s = self.channel.subclient = Mock() s.parse_response.return_value = ['message', '/foo/', 0, 'data'] - with pytest.raises(redis.Empty): - self.channel._receive() + assert self.channel._receive_one(self.channel.subclient) is None def test_brpop_read_raises(self): c = self.channel.client = Mock() @@ -1152,7 +1158,7 @@ class test_MultiChannelPoller: p, channel = self.create_get() with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) def test_qos_reject(self): p, channel = self.create_get() @@ -1166,7 +1172,7 @@ class test_MultiChannelPoller: channel.qos.can_consume.return_value = True with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) p._register_BRPOP.assert_called_with(channel) @@ -1175,7 +1181,7 @@ class test_MultiChannelPoller: channel.qos.can_consume.return_value = False with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) p._register_BRPOP.assert_not_called() @@ -1183,7 +1189,7 @@ class test_MultiChannelPoller: p, channel = self.create_get(fanouts=['f_queue']) with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) p._register_LISTEN.assert_called_with(channel) @@ -1192,7 +1198,7 @@ class test_MultiChannelPoller: p._fd_to_chan[1] = (channel, 'BRPOP') with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') @@ -1202,7 +1208,7 @@ class test_MultiChannelPoller: p._fd_to_chan[1] = (channel, 'BRPOP') with pytest.raises(redis.Empty): - p.get() + p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') diff --git a/t/unit/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index 5eb23a08..f4acbda0 100644 --- a/t/unit/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -166,7 +166,7 @@ class test_AbstractChannel: def test_poll(self): cycle = Mock(name='cycle') - assert virtual.AbstractChannel()._poll(cycle) + assert virtual.AbstractChannel()._poll(cycle, Mock()) cycle.get.assert_called() @@ -330,6 +330,12 @@ class test_Channel: c.queue_bind(n, n, n) c.queue_declare(n + '2') c.queue_bind(n + '2', n, n) + messages = [] + c.connection._deliver = Mock(name='_deliver') + + def on_deliver(message, queue): + messages.append(message) + c.connection._deliver.side_effect = on_deliver m = c.prepare_message('nthex quick brown fox...') c.basic_publish(m, n, n) @@ -344,8 +350,8 @@ class test_Channel: c.basic_consume(n + '2', False, consumer_tag=consumer_tag, callback=lambda *a: None) assert n + '2' in c._active_queues - r2, _ = c.drain_events() - r2 = c.message_to_python(r2) + c.drain_events() + r2 = c.message_to_python(messages[-1]) assert r2.body == 'nthex quick brown fox...'.encode('utf-8') assert r2.delivery_info['exchange'] == n assert r2.delivery_info['routing_key'] == n @@ -561,7 +567,7 @@ class test_Transport: def test_drain_channel(self): channel = self.transport.create_channel(self.transport) with pytest.raises(virtual.Empty): - self.transport._drain_channel(channel) + self.transport._drain_channel(channel, Mock()) def test__deliver__no_queue(self): with pytest.raises(KeyError): diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py index 894286cd..79216b56 100644 --- a/t/unit/utils/test_scheduling.py +++ b/t/unit/utils/test_scheduling.py @@ -2,6 +2,8 @@ from __future__ import absolute_import, unicode_literals import pytest +from case import Mock + from kombu.utils.scheduling import FairCycle, cycle_by_name @@ -12,7 +14,7 @@ class MyEmpty(Exception): def consume(fun, n): r = [] for i in range(n): - r.append(fun()) + r.append(fun(Mock(name='callback'))) return r @@ -20,6 +22,7 @@ class test_FairCycle: def test_cycle(self): resources = ['a', 'b', 'c', 'd', 'e'] + callback = Mock(name='callback') def echo(r, timeout=None): return r @@ -27,26 +30,24 @@ class test_FairCycle: # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat] cycle = FairCycle(echo, resources, MyEmpty) for i in range(len(resources)): - assert cycle.get() == (resources[i], resources[i]) + assert cycle.get(callback) == resources[i] for i in range(len(resources)): - assert cycle.get() == (resources[i], resources[i]) + assert cycle.get(callback) == resources[i] def test_cycle_breaks(self): resources = ['a', 'b', 'c', 'd', 'e'] - def echo(r): + def echo(r, callback): if r == 'c': raise MyEmpty(r) return r cycle = FairCycle(echo, resources, MyEmpty) assert consume(cycle.get, len(resources)) == [ - ('a', 'a'), ('b', 'b'), ('d', 'd'), - ('e', 'e'), ('a', 'a'), + 'a', 'b', 'd', 'e', 'a', ] assert consume(cycle.get, len(resources)) == [ - ('b', 'b'), ('d', 'd'), ('e', 'e'), - ('a', 'a'), ('b', 'b'), + 'b', 'd', 'e', 'a', 'b', ] cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty) with pytest.raises(MyEmpty):