From 128ea635e8d8eeeb9b3e71ea4fbe962e287545e8 Mon Sep 17 00:00:00 2001 From: Ask Solem Date: Wed, 21 Jul 2010 16:07:45 +0200 Subject: [PATCH] Redis backend (basic testing with celery works) --- kombu/backends/__init__.py | 2 + kombu/backends/emulation.py | 234 +++++++++++++++++++++++------------- kombu/backends/memory.py | 9 +- kombu/backends/pyredis.py | 85 +++++++++++++ kombu/utils.py | 214 +++++++++++++++++++++++++++++++++ 5 files changed, 461 insertions(+), 83 deletions(-) create mode 100644 kombu/backends/pyredis.py diff --git a/kombu/backends/__init__.py b/kombu/backends/__init__.py index 4245873c..f522574f 100644 --- a/kombu/backends/__init__.py +++ b/kombu/backends/__init__.py @@ -9,6 +9,8 @@ BACKEND_ALIASES = { "pika": "kombu.backends.pypika.AsyncoreBackend", "syncpika": "kombu.backends.pypika.SyncBackend", "memory": "kombu.backends.memory.MemoryBackend", + "redis": "kombu.backends.pyredis.RedisBackend", + "nbredis": "kombu.backends.pyredis.NBRedisBackend", } _backend_cache = {} diff --git a/kombu/backends/emulation.py b/kombu/backends/emulation.py index 9ef26483..90595ab5 100644 --- a/kombu/backends/emulation.py +++ b/kombu/backends/emulation.py @@ -1,102 +1,106 @@ -from kombu.backends.base import BaseBackend, BaseMessage -from anyjson import deserialize, serialize -from itertools import count -from collections import OrderedDict -import sys -import time import atexit +import pickle +import sys +import tempfile +from time import sleep +from itertools import count, cycle from Queue import Empty as QueueEmpty -from itertools import cycle + +from kombu.backends.base import BaseBackend, BaseMessage +from kombu.utils import OrderedDict -class QueueSet(object): - """A set of queues that operates as one.""" +class Consume(object): + """Consume from a set of resources, where each resource gets + an equal chance to be consumed from.""" - def __init__(self, backend, queues): - self.backend = backend - self.queues = queues + def __init__(self, fun, resources, predicate=QueueEmpty): + self.fun = fun + self.resources = resources + self.predicate = predicate # an infinite cycle through all the queues. - self.cycle = cycle(self.queues) + self.cycle = cycle(self.resources) - # A set of all the queue names, so we can match when we've + # A set of all the names, so we can match when we've # tried all of them. - self.all = frozenset(self.queues) + self.all = frozenset(self.resources) def get(self): - """Get the next message avaiable in the queue. - - :returns: The message and the name of the queue it came from as - a tuple. - :raises Queue.Empty: If there are no more items in any of the queues. - - """ - - # A set of queues we've already tried. + # What we've already tried. tried = set() while True: - # Get the next queue in the cycle, and try to get an item off it. - queue = self.cycle.next() + # Get the next resource in the cycle, + # and try to get an item off it. + resource = self.cycle.next() try: - item = self.backend._get(queue) - except QueueEmpty: - # raises Empty when we've tried all of them. - tried.add(queue) + return self.fun(resource), resource + except self.predicate: + tried.add(resource) if tried == self.all: raise - else: - return item, queue - - def __repr__(self): - return "" % repr(self.queue_names) class QualityOfService(object): - def __init__(self, resource, prefetch_count=None, interval=None, + def __init__(self, channel, prefetch_count=None, interval=None, do_restore=True): - self.resource = resource + self.channel = channel self.prefetch_count = prefetch_count self.interval = interval self._delivered = OrderedDict() self.do_restore = do_restore self._restored_once = False - atexit.register(self.restore_unacked_once) + if self.do_restore: + atexit.register(self.restore_unacked_once) def can_consume(self): - return len(self._delivered) > self.prefetch_count + if not self.prefetch_count: + return True + return len(self._delivered) < self.prefetch_count - def append(self, message, queue_name, delivery_tag): - self._delivered[delivery_tag] = message, queue_name + def append(self, message, delivery_tag): + self._delivered[delivery_tag] = message def ack(self, delivery_tag): self._delivered.pop(delivery_tag, None) def restore_unacked(self): - for message, queue_name in self._delivered.items(): - self.resource._put(queue_name, message) - self._delivered = SortedDict() + for message in self._delivered.items(): + self.channel._restore(message) + self._delivered.clear() def requeue(self, delivery_tag): try: - message, queue_name = self._delivered.pop(delivery_tag) + message = self._delivered.pop(delivery_tag) except KeyError: pass - self.resource.put(queue_name, message) + self.channel._restore(message) def restore_unacked_once(self): - if self.do_restore: - if not self._restored_once: - if self._delivered: - sys.stderr.write( - "Restoring unacknowledged messages: %s\n" % ( - self._delivered)) + if self.do_restore and not self._restored_once: + if self._delivered: + sys.stderr.write( + "Restoring unacknowledged messages: %s\n" % ( + self._delivered)) + try: self.restore_unacked() - if self._delivered: - sys.stderr.write("UNRESTORED MESSAGES: %s\n" % ( - self._delivered)) + except: + pass + if self._delivered: + sys.stderr.write("UNABLE TO RESTORE %s MESSAGES\n" % ( + len(self._delivered))) + persist = tempfile.mktemp() + sys.stderr.write( + "PERSISTING UNRESTORED MESSAGES TO FILE: %s\n" % persist) + fh = open(persist, "w") + try: + pickle.dump(self._delivered, fh, protocol=0) + finally: + fh.flush() + fh.close() class Message(BaseMessage): @@ -111,7 +115,6 @@ class Message(BaseMessage): kwargs["headers"] = payload.get("headers") kwargs["properties"] = properties kwargs["delivery_info"] = properties.get("delivery_info") - self.destination = payload.get("destination") super(Message, self).__init__(channel, **kwargs) @@ -121,7 +124,6 @@ class Message(BaseMessage): _exchanges = {} -_queues = {} _consumers = {} _callbacks = {} @@ -149,8 +151,41 @@ class Channel(object): def _purge(self, queue): raise NotImplementedError("Emulations must implement _purge") + def _size(self, queue): + return 0 + + def _delete(self, queue): + self._purge(queue) + def _new_queue(self, queue): - raise NotImplementedError("Emulations must implement _new_queue") + pass + + def _lookup(self, exchange, routing_key, default="ae.undeliver"): + try: + return _exchanges[exchange]["table"][routing_key] + except KeyError: + self._new_queue(default) + return default + + def _restore(self, message): + delivery_info = message.delivery_info + self._put(self._lookup(delivery_info["exchange"], + delivery_info["routing_key"]), + message) + + def _poll(self, resource): + while True: + if self.qos_manager.can_consume(): + try: + return resource.get() + except QueueEmpty: + pass + + def drain_events(self, timeout=None): + if self.qos_manager.can_consume(): + queues = [_consumers[tag] for tag in self._consumers] + return Consume(self._get, queues, QueueEmpty).get() + raise QueueEmpty() def exchange_declare(self, exchange, type="direct", durable=False, auto_delete=False, arguments=None): @@ -161,9 +196,19 @@ class Channel(object): "arguments": arguments or {}, "table": {}} + def exchange_delete(self, exchange, if_unused=False): + for rkey, queue in _exchanges[exchange]["table"].items(): + self._purge(queue) + _exchanges.pop(exchange, None) + def queue_declare(self, queue, **kwargs): - if queue not in _queues: - _queues[queue] = self._new_queue(queue, **kwargs) + self._new_queue(queue, **kwargs) + return queue, self._size(queue), 0 + + def queue_delete(self, queue, if_unusued=False, if_empty=False): + if if_empty and self._size(queue): + return + self._delete(queue) def queue_bind(self, queue, exchange, routing_key, arguments=None): table = _exchanges[exchange].setdefault("table", {}) @@ -187,6 +232,11 @@ class Channel(object): def basic_ack(self, delivery_tag): self.qos_manager.ack(delivery_tag) + def basic_recover(self, requeue=False): + if requeue: + return self.qos_manager.restore_unacked() + raise NotImplementedError("Does not support recover(requeue=False)") + def basic_reject(self, delivery_tag, requeue=False): if requeue: self.qos_manager.requeue(delivery_tag) @@ -198,11 +248,10 @@ class Channel(object): self._consumers.add(consumer_tag) def basic_publish(self, message, exchange, routing_key, **kwargs): - message["destination"] = exchange + message["properties"]["delivery_info"]["exchange"] = exchange + message["properties"]["delivery_info"]["routing_key"] = routing_key message["properties"]["delivery_tag"] = self._next_delivery_tag() - table = _exchanges[exchange]["table"] - if routing_key in table: - self._put(table[routing_key], message) + self._put(self._lookup(exchange, routing_key), message) def basic_cancel(self, consumer_tag): queue = _consumers.pop(consumer_tag, None) @@ -211,13 +260,14 @@ class Channel(object): def message_to_python(self, raw_message): message = self.Message(self, payload=raw_message) - self.qos_manager.append(message, message.destination, - message.delivery_tag) + self.qos_manager.append(message, message.delivery_tag) return message def prepare_message(self, message_data, priority=None, content_type=None, content_encoding=None, headers=None, properties=None): + properties = properties or {} + properties.setdefault("delivery_info", {}) return {"body": message_data, "priority": priority or 0, "content-encoding": content_encoding, @@ -238,41 +288,61 @@ class Channel(object): return self._qos_manager def close(self): - map(self.basic_cancel, self._consumers) + map(self.basic_cancel, list(self._consumers)) + self.connection.close_channel(self) class EmulationBase(BaseBackend): Channel = Channel - QueueSet = QueueSet + Consume = Consume + interval = 1 default_port = None def __init__(self, connection, **kwargs): self.connection = connection + self._channels = set() def create_channel(self, connection): - return self.Channel(connection) + channel = self.Channel(connection) + self._channels.add(channel) + return channel + + def close_channel(self, channel): + try: + self._channels.remove(channel) + except KeyError: + pass def establish_connection(self): return self # for drain events def close_connection(self, connection): - pass + while self._channels: + try: + channel = self._channels.pop() + except KeyError: + pass + else: + channel.close() - def _poll(self, resource): - while True: - if self.qos_manager.can_consume(): - try: - return resource.get() - except QueueEmpty: - pass - time.sleep(self.interval) + def _drain_channel(self, channel): + return channel.drain_events() def drain_events(self, timeout=None): - queueset = self.QueueSet(self._consumers.values()) - payload, queue = self._poll(queueset) + consumer = Consume(self._drain_channel, self._channels, QueueEmpty) + while True: + try: + item, channel = consumer.get() + break + except QueueEmpty: + sleep(self.interval) + + message, queue = item if not queue or queue not in _callbacks: - return + raise KeyError( + "Received message for queue '%s' without consumers: %s" % ( + queue, message)) - _callbacks[queue](payload) + _callbacks[queue](message) diff --git a/kombu/backends/memory.py b/kombu/backends/memory.py index bf9e3d18..a7b21559 100644 --- a/kombu/backends/memory.py +++ b/kombu/backends/memory.py @@ -8,7 +8,8 @@ class MemoryChannel(emulation.Channel): do_restore = False def _new_queue(self, queue, **kwargs): - self.queues[queue] = Queue() + if queue not in self.queues: + self.queues[queue] = Queue() def _get(self, queue): return self.queues[queue].get(block=False) @@ -16,6 +17,12 @@ class MemoryChannel(emulation.Channel): def _put(self, queue, message): self.queues[queue].put(message) + def _size(self, queue): + return self.queues[queue].qsize() + + def _delete(self, queue): + self.queues.pop(queue, None) + def _purge(self, queue): size = self.queues[queue].qsize() self.queues[queue].queue.clear() diff --git a/kombu/backends/pyredis.py b/kombu/backends/pyredis.py new file mode 100644 index 00000000..e75928b4 --- /dev/null +++ b/kombu/backends/pyredis.py @@ -0,0 +1,85 @@ +from Queue import Empty + +from anyjson import serialize, deserialize +from redis import Redis +from redis import exceptions + +from kombu.backends import emulation + +DEFAULT_PORT = 6379 +DEFAULT_DB = 0 + + +class RedisChannel(emulation.Channel): + queues = {} + do_restore = False + _client = None + + def _new_queue(self, queue, **kwargs): + pass + + def _get(self, queue): + item = self.client.rpop(queue) + if item: + return deserialize(item) + raise Empty() + + def _size(self, queue): + return self.client.llen(queue) + + def _get_many(self, queue, timeout=None): + dest__item = self.client.brpop(queues, timeout=timeout) + if dest__item: + dest, item = dest__item + return deserialize(dest), item + raise Empty() + + def _put(self, queue, message): + self.client.lpush(queue, serialize(message)) + + def _purge(self, queue): + size = self.client.llen(queue) + self.client.delete(queue) + return size + + def close(self): + super(RedisChannel, self).close() + try: + self.client.bgsave() + except exceptions.ResponseError: + pass + + def _open(self): + conninfo = self.connection.connection + database = conninfo.virtual_host + if not isinstance(database, int): + if not database or database == "/": + database = DEFAULT_DB + elif database.startswith("/"): + database = database[1:] + try: + database = int(database) + except ValueError: + raise ValueError( + "Database name must be int between 0 and limit - 1") + + return Redis(host=conninfo.hostname, + port=conninfo.port or DEFAULT_PORT, + db=database, + password=conninfo.password) + + @property + def client(self): + if self._client is None: + self._client = self._open() + return self._client + + +class RedisBackend(emulation.EmulationBase): + Channel = RedisChannel + + default_port = DEFAULT_PORT + connection_errors = (exceptions.ConnectionError, ) + channel_errors = (exceptions.ConnectionError, + exceptions.InvalidResponse, + exceptions.ResponseError) diff --git a/kombu/utils.py b/kombu/utils.py index 2feebaab..d0770b9f 100644 --- a/kombu/utils.py +++ b/kombu/utils.py @@ -6,6 +6,8 @@ def maybe_list(v): return [v] +############## str.partition/str.rpartition ################################# + def _compat_rl_partition(S, sep, direction=None, reverse=False): items = direction(sep, 1) if len(items) == 1: @@ -50,3 +52,215 @@ def rpartition(S, sep): return S.rpartition(sep) else: # Python <= 2.4: return _compat_rpartition(S, sep) + + +############## collections.OrderedDict ####################################### + +import weakref +try: + from collections import MutableMapping +except ImportError: + from UserDict import DictMixin as MutableMapping +from itertools import imap as _imap +from operator import eq as _eq + + +class _Link(object): + """Doubly linked list.""" + __slots__ = 'prev', 'next', 'key', '__weakref__' + + +class OrderedDict(dict, MutableMapping): + """Dictionary that remembers insertion order""" + # An inherited dict maps keys to values. + # The inherited dict provides __getitem__, __len__, __contains__, and get. + # The remaining methods are order-aware. + # Big-O running times for all methods are the same as for regular + # dictionaries. + + # The internal self.__map dictionary maps keys to links in a doubly + # linked list. + # The circular doubly linked list starts and ends with a sentinel element. + # The sentinel element never gets deleted (this simplifies the algorithm). + # The prev/next links are weakref proxies (to prevent circular + # references). + # Individual links are kept alive by the hard reference in self.__map. + # Those hard references disappear when a key is deleted from + # an OrderedDict. + + __marker = object() + + def __init__(self, *args, **kwds): + """Initialize an ordered dictionary. + + Signature is the same as for regular dictionaries, but keyword + arguments are not recommended because their insertion order is + arbitrary. + + """ + if len(args) > 1: + raise TypeError("expected at most 1 arguments, got %d" % ( + len(args))) + try: + self.__root + except AttributeError: + # sentinel node for the doubly linked list + self.__root = root = _Link() + root.prev = root.next = root + self.__map = {} + self.update(*args, **kwds) + + def clear(self): + "od.clear() -> None. Remove all items from od." + root = self.__root + root.prev = root.next = root + self.__map.clear() + dict.clear(self) + + def __setitem__(self, key, value): + "od.__setitem__(i, y) <==> od[i]=y" + # Setting a new item creates a new link which goes at the end of the + # linked list, and the inherited dictionary is updated with the new + # key/value pair. + if key not in self: + self.__map[key] = link = _Link() + root = self.__root + last = root.prev + link.prev, link.next, link.key = last, root, key + last.next = root.prev = weakref.proxy(link) + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + """od.__delitem__(y) <==> del od[y]""" + # Deleting an existing item uses self.__map to find the + # link which is then removed by updating the links in the + # predecessor and successor nodes. + dict.__delitem__(self, key) + link = self.__map.pop(key) + link.prev.next = link.next + link.next.prev = link.prev + + def __iter__(self): + """od.__iter__() <==> iter(od)""" + # Traverse the linked list in order. + root = self.__root + curr = root.next + while curr is not root: + yield curr.key + curr = curr.next + + def __reversed__(self): + """od.__reversed__() <==> reversed(od)""" + # Traverse the linked list in reverse order. + root = self.__root + curr = root.prev + while curr is not root: + yield curr.key + curr = curr.prev + + def __reduce__(self): + """Return state information for pickling""" + items = [[k, self[k]] for k in self] + tmp = self.__map, self.__root + del(self.__map, self.__root) + inst_dict = vars(self).copy() + self.__map, self.__root = tmp + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def update(self, other=(), **kwds): + if isinstance(other, dict): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def values(self): + return [self[key] for key in self] + + def items(self): + return [(key, self[key]) for key in self] + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def iterkeys(self): + return iter(self) + + def keys(self): + return list(self) + + def popitem(self, last=True): + """od.popitem() -> (k, v) + + Return and remove a (key, value) pair. + Pairs are returned in LIFO order if last is true or FIFO + order if false. + + """ + if not self: + raise KeyError('dictionary is empty') + key = (last and reversed(self) or iter(self)).next() + value = self.pop(key) + return key, value + + def __repr__(self): + "od.__repr__() <==> repr(od)" + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + + def copy(self): + "od.copy() -> a shallow copy of od" + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + """OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S + and values equal to v (which defaults to None).""" + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + """od.__eq__(y) <==> od==y. Comparison to another OD is + order-sensitive while comparison to a regular mapping + is order-insensitive.""" + if isinstance(other, OrderedDict): + return len(self) == len(other) and \ + all(_imap(_eq, self.iteritems(), other.iteritems())) + return dict.__eq__(self, other) + + def __ne__(self, other): + return not (self == other)