SQS improvements

This commit is contained in:
Ask Solem 2011-06-15 14:26:48 +01:00
parent 69b36c91a5
commit af4d2b0cfd
5 changed files with 87 additions and 28 deletions

View File

@ -253,6 +253,7 @@ class Mailbox(object):
self.connection.drain_events(timeout=timeout)
except socket.timeout:
break
chan.after_reply_message_received(queue.name)
return responses
finally:
channel or chan.close()

View File

@ -39,6 +39,8 @@ CHARS_REPLACE_TABLE = string.maketrans(CHARS_REPLACE + '.',
class Table(Domain):
"""Amazon SimpleDB domain describing the message routing table."""
# caches queues already bound, so we don't have to declare them again.
_already_bound = set()
def routes_for(self, exchange):
"""Iterator giving all routes for an exchange."""
@ -62,21 +64,37 @@ class Table(Domain):
id = gen_unique_id()
return self.new_item(id), id
def queue_bind(self, exchange, routing_key, pattern, queue):
if queue not in self._already_bound:
binding, id = self.create_binding(queue)
binding.update(exchange=exchange,
routing_key=routing_key or "",
pattern=pattern or "",
queue=queue or "",
id=id)
binding.save()
self._already_bound.add(queue)
def queue_delete(self, queue):
"""delete queue by name."""
qid = self._get_queue_id(queue)
if qid:
self.delete_item(qid)
self._already_bound.discard(queue)
item = self._get_queue_item(queue)
if item:
self.delete_item(item)
def exchange_delete(self, exchange):
"""Delete all routes for `exchange`."""
self._already_bound.discard(queue)
for item in self.routes_for(exchange):
self.delete_item(item["id"])
def get_item(self, item_name, consistent_read=True):
def get_item(self, item_name):
"""Uses `consistent_read` by default."""
# Domain is an old-style class, can't use super().
return Domain.get_item(self, item_name, consistent_read)
for consistent_read in (False, True):
item = Domain.get_item(self, item_name, consistent_read)
if item:
return item
def select(self, query='', next_token=None, consistent_read=True,
max_items=None):
@ -85,9 +103,17 @@ class Table(Domain):
return Domain.select(self, query, next_token,
consistent_read, max_items)
def _try_first(self, query='', **kwargs):
for c in (False, True):
for item in self.select(query, consistent_read=c, **kwargs):
return item
def _get_queue_item(self, queue):
return self._try_first("""queue = '%s' limit 1""" % queue)
def _get_queue_id(self, queue):
for item in self.select("""queue = '%s' limit 1""" % queue,
max_items=1):
item = self._get_queue_item(queue)
if item:
return item["id"]
@ -98,6 +124,20 @@ class Channel(virtual.Channel):
domain_format = "kombu%(vhost)s"
_sdb = None
_sqs = None
_queue_cache = {}
_noack_queues = set()
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super(Channel, self).basic_consume(queue, no_ack,
*args, **kwargs)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super(Channel, self).basic_cancel(consumer_tag)
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a legal SQS queue name."""
@ -105,25 +145,23 @@ class Channel(virtual.Channel):
def _new_queue(self, queue, **kwargs):
"""Ensures a queue exists in SQS."""
return self.sqs.create_queue(self.entity_name(queue),
self.visibility_timeout)
try:
return self._queue_cache[queue]
except KeyError:
q = self._queue_cache[queue] = self.sqs.create_queue(
self.entity_name(queue),
self.visibility_timeout)
return q
def _queue_bind(self, exchange, routing_key, pattern, queue):
def _queue_bind(self, *args):
"""Bind ``queue`` to ``exchange`` with routing key.
Route will be stored in SDB if so enabled.
"""
if not self.supports_fanout:
return
if self.supports_fanout:
self.table.queue_bind(*args)
binding, id = self.table.create_binding(queue)
binding.update(exchange=exchange,
routing_key=routing_key or "",
pattern=pattern or "",
queue=queue or "",
id=id)
binding.save()
def get_table(self, exchange):
"""Get routing table.
@ -138,6 +176,7 @@ class Channel(virtual.Channel):
def _delete(self, queue):
"""delete queue by name."""
self._queue_cache.pop(queue, None)
self.table.queue_delete(queue)
super(Channel, self)._delete(queue)
@ -172,14 +211,22 @@ class Channel(virtual.Channel):
if rs:
m = rs[0]
payload = deserialize(rs[0].get_body())
payload["properties"]["delivery_info"].update({
"sqs_message": m, "sqs_queue": q})
if queue in self._noack_queues:
q.delete_message(m)
else:
payload["properties"]["delivery_info"].update({
"sqs_message": m, "sqs_queue": q, })
return payload
raise Empty()
def basic_ack(self, delivery_tag):
delivery_info = self.qos.get(delivery_tag).delivery_info
delivery_info["sqs_queue"].delete_message(delivery_info["sqs_message"])
try:
queue = delivery_info["sqs_queue"]
except KeyError:
pass
else:
queue.delete_message(delivery_info["sqs_message"])
super(Channel, self).basic_ack(delivery_tag)
def _size(self, queue):
@ -268,7 +315,7 @@ class Channel(virtual.Channel):
class Transport(virtual.Transport):
Channel = Channel
interval = 1
polling_interval = 1
default_port = None
connection_errors = (exception.SQSError, socket.error)
channel_errors = (exception.SQSDecodeError, )

View File

@ -30,6 +30,11 @@ class StdChannel(object):
raise NotImplementedError("%r does not implement list_bindings" % (
self.__class__, ))
def after_reply_message_received(self, queue):
"""reply queue semantics: can be used to delete the queue
after transient reply message received."""
pass
class Message(object):
"""Base class for received messages."""

View File

@ -359,6 +359,9 @@ class Channel(AbstractChannel, base.StdChannel):
self._delete(queue)
self.state.bindings.pop(queue, None)
def after_reply_message_received(self, queue):
self.queue_delete(queue)
def queue_bind(self, queue, exchange, routing_key, arguments=None,
**kwargs):
"""Bind `queue` to `exchange` with `routing key`."""
@ -472,9 +475,8 @@ class Channel(AbstractChannel, base.StdChannel):
"""
try:
table = self.get_table(exchange)
return self.typeof(exchange).lookup(table, exchange,
routing_key, default)
return self.typeof(exchange).lookup(self.get_table(exchange),
exchange, routing_key, default)
except KeyError:
self._new_queue(default)
return [default]
@ -622,7 +624,7 @@ class Transport(base.Transport):
def create_channel(self, connection):
try:
channel = self._avail_channels.pop()
return self._avail_channels.pop()
except IndexError:
channel = self.Channel(connection)
self.channels.append(channel)
@ -638,6 +640,9 @@ class Transport(base.Transport):
channel.connection = None
def establish_connection(self):
# creates channel to verify connection.
# this channel is then used as the next requested channel.
# (returned by ``create_channel``).
self._avail_channels.append(self.create_channel(self))
return self # for drain events

View File

@ -46,4 +46,5 @@ class FairCycle(object):
pass
def __repr__(self):
return "<FairCycle: %r>" % (self.resources, )
return "<FairCycle: %r/%r %r>" % (self.pos, len(self.resources),
self.resources, )