mirror of https://github.com/celery/kombu.git
SQS improvements
This commit is contained in:
parent
69b36c91a5
commit
af4d2b0cfd
|
@ -253,6 +253,7 @@ class Mailbox(object):
|
|||
self.connection.drain_events(timeout=timeout)
|
||||
except socket.timeout:
|
||||
break
|
||||
chan.after_reply_message_received(queue.name)
|
||||
return responses
|
||||
finally:
|
||||
channel or chan.close()
|
||||
|
|
|
@ -39,6 +39,8 @@ CHARS_REPLACE_TABLE = string.maketrans(CHARS_REPLACE + '.',
|
|||
|
||||
class Table(Domain):
|
||||
"""Amazon SimpleDB domain describing the message routing table."""
|
||||
# caches queues already bound, so we don't have to declare them again.
|
||||
_already_bound = set()
|
||||
|
||||
def routes_for(self, exchange):
|
||||
"""Iterator giving all routes for an exchange."""
|
||||
|
@ -62,21 +64,37 @@ class Table(Domain):
|
|||
id = gen_unique_id()
|
||||
return self.new_item(id), id
|
||||
|
||||
def queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
if queue not in self._already_bound:
|
||||
binding, id = self.create_binding(queue)
|
||||
binding.update(exchange=exchange,
|
||||
routing_key=routing_key or "",
|
||||
pattern=pattern or "",
|
||||
queue=queue or "",
|
||||
id=id)
|
||||
binding.save()
|
||||
self._already_bound.add(queue)
|
||||
|
||||
def queue_delete(self, queue):
|
||||
"""delete queue by name."""
|
||||
qid = self._get_queue_id(queue)
|
||||
if qid:
|
||||
self.delete_item(qid)
|
||||
self._already_bound.discard(queue)
|
||||
item = self._get_queue_item(queue)
|
||||
if item:
|
||||
self.delete_item(item)
|
||||
|
||||
def exchange_delete(self, exchange):
|
||||
"""Delete all routes for `exchange`."""
|
||||
self._already_bound.discard(queue)
|
||||
for item in self.routes_for(exchange):
|
||||
self.delete_item(item["id"])
|
||||
|
||||
def get_item(self, item_name, consistent_read=True):
|
||||
def get_item(self, item_name):
|
||||
"""Uses `consistent_read` by default."""
|
||||
# Domain is an old-style class, can't use super().
|
||||
return Domain.get_item(self, item_name, consistent_read)
|
||||
for consistent_read in (False, True):
|
||||
item = Domain.get_item(self, item_name, consistent_read)
|
||||
if item:
|
||||
return item
|
||||
|
||||
def select(self, query='', next_token=None, consistent_read=True,
|
||||
max_items=None):
|
||||
|
@ -85,9 +103,17 @@ class Table(Domain):
|
|||
return Domain.select(self, query, next_token,
|
||||
consistent_read, max_items)
|
||||
|
||||
def _try_first(self, query='', **kwargs):
|
||||
for c in (False, True):
|
||||
for item in self.select(query, consistent_read=c, **kwargs):
|
||||
return item
|
||||
|
||||
def _get_queue_item(self, queue):
|
||||
return self._try_first("""queue = '%s' limit 1""" % queue)
|
||||
|
||||
def _get_queue_id(self, queue):
|
||||
for item in self.select("""queue = '%s' limit 1""" % queue,
|
||||
max_items=1):
|
||||
item = self._get_queue_item(queue)
|
||||
if item:
|
||||
return item["id"]
|
||||
|
||||
|
||||
|
@ -98,6 +124,20 @@ class Channel(virtual.Channel):
|
|||
domain_format = "kombu%(vhost)s"
|
||||
_sdb = None
|
||||
_sqs = None
|
||||
_queue_cache = {}
|
||||
_noack_queues = set()
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
return super(Channel, self).basic_consume(queue, no_ack,
|
||||
*args, **kwargs)
|
||||
|
||||
def basic_cancel(self, consumer_tag):
|
||||
if consumer_tag in self._consumers:
|
||||
queue = self._tag_to_queue[consumer_tag]
|
||||
self._noack_queues.discard(queue)
|
||||
return super(Channel, self).basic_cancel(consumer_tag)
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
|
||||
"""Format AMQP queue name into a legal SQS queue name."""
|
||||
|
@ -105,25 +145,23 @@ class Channel(virtual.Channel):
|
|||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
"""Ensures a queue exists in SQS."""
|
||||
return self.sqs.create_queue(self.entity_name(queue),
|
||||
self.visibility_timeout)
|
||||
try:
|
||||
return self._queue_cache[queue]
|
||||
except KeyError:
|
||||
q = self._queue_cache[queue] = self.sqs.create_queue(
|
||||
self.entity_name(queue),
|
||||
self.visibility_timeout)
|
||||
return q
|
||||
|
||||
def _queue_bind(self, exchange, routing_key, pattern, queue):
|
||||
def _queue_bind(self, *args):
|
||||
"""Bind ``queue`` to ``exchange`` with routing key.
|
||||
|
||||
Route will be stored in SDB if so enabled.
|
||||
|
||||
"""
|
||||
if not self.supports_fanout:
|
||||
return
|
||||
if self.supports_fanout:
|
||||
self.table.queue_bind(*args)
|
||||
|
||||
binding, id = self.table.create_binding(queue)
|
||||
binding.update(exchange=exchange,
|
||||
routing_key=routing_key or "",
|
||||
pattern=pattern or "",
|
||||
queue=queue or "",
|
||||
id=id)
|
||||
binding.save()
|
||||
|
||||
def get_table(self, exchange):
|
||||
"""Get routing table.
|
||||
|
@ -138,6 +176,7 @@ class Channel(virtual.Channel):
|
|||
|
||||
def _delete(self, queue):
|
||||
"""delete queue by name."""
|
||||
self._queue_cache.pop(queue, None)
|
||||
self.table.queue_delete(queue)
|
||||
super(Channel, self)._delete(queue)
|
||||
|
||||
|
@ -172,14 +211,22 @@ class Channel(virtual.Channel):
|
|||
if rs:
|
||||
m = rs[0]
|
||||
payload = deserialize(rs[0].get_body())
|
||||
payload["properties"]["delivery_info"].update({
|
||||
"sqs_message": m, "sqs_queue": q})
|
||||
if queue in self._noack_queues:
|
||||
q.delete_message(m)
|
||||
else:
|
||||
payload["properties"]["delivery_info"].update({
|
||||
"sqs_message": m, "sqs_queue": q, })
|
||||
return payload
|
||||
raise Empty()
|
||||
|
||||
def basic_ack(self, delivery_tag):
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
delivery_info["sqs_queue"].delete_message(delivery_info["sqs_message"])
|
||||
try:
|
||||
queue = delivery_info["sqs_queue"]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
queue.delete_message(delivery_info["sqs_message"])
|
||||
super(Channel, self).basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue):
|
||||
|
@ -268,7 +315,7 @@ class Channel(virtual.Channel):
|
|||
class Transport(virtual.Transport):
|
||||
Channel = Channel
|
||||
|
||||
interval = 1
|
||||
polling_interval = 1
|
||||
default_port = None
|
||||
connection_errors = (exception.SQSError, socket.error)
|
||||
channel_errors = (exception.SQSDecodeError, )
|
||||
|
|
|
@ -30,6 +30,11 @@ class StdChannel(object):
|
|||
raise NotImplementedError("%r does not implement list_bindings" % (
|
||||
self.__class__, ))
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
"""reply queue semantics: can be used to delete the queue
|
||||
after transient reply message received."""
|
||||
pass
|
||||
|
||||
|
||||
class Message(object):
|
||||
"""Base class for received messages."""
|
||||
|
|
|
@ -359,6 +359,9 @@ class Channel(AbstractChannel, base.StdChannel):
|
|||
self._delete(queue)
|
||||
self.state.bindings.pop(queue, None)
|
||||
|
||||
def after_reply_message_received(self, queue):
|
||||
self.queue_delete(queue)
|
||||
|
||||
def queue_bind(self, queue, exchange, routing_key, arguments=None,
|
||||
**kwargs):
|
||||
"""Bind `queue` to `exchange` with `routing key`."""
|
||||
|
@ -472,9 +475,8 @@ class Channel(AbstractChannel, base.StdChannel):
|
|||
|
||||
"""
|
||||
try:
|
||||
table = self.get_table(exchange)
|
||||
return self.typeof(exchange).lookup(table, exchange,
|
||||
routing_key, default)
|
||||
return self.typeof(exchange).lookup(self.get_table(exchange),
|
||||
exchange, routing_key, default)
|
||||
except KeyError:
|
||||
self._new_queue(default)
|
||||
return [default]
|
||||
|
@ -622,7 +624,7 @@ class Transport(base.Transport):
|
|||
|
||||
def create_channel(self, connection):
|
||||
try:
|
||||
channel = self._avail_channels.pop()
|
||||
return self._avail_channels.pop()
|
||||
except IndexError:
|
||||
channel = self.Channel(connection)
|
||||
self.channels.append(channel)
|
||||
|
@ -638,6 +640,9 @@ class Transport(base.Transport):
|
|||
channel.connection = None
|
||||
|
||||
def establish_connection(self):
|
||||
# creates channel to verify connection.
|
||||
# this channel is then used as the next requested channel.
|
||||
# (returned by ``create_channel``).
|
||||
self._avail_channels.append(self.create_channel(self))
|
||||
return self # for drain events
|
||||
|
||||
|
|
|
@ -46,4 +46,5 @@ class FairCycle(object):
|
|||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "<FairCycle: %r>" % (self.resources, )
|
||||
return "<FairCycle: %r/%r %r>" % (self.pos, len(self.resources),
|
||||
self.resources, )
|
||||
|
|
Loading…
Reference in New Issue