diff --git a/kombu/transport/azurestoragequeues.py b/kombu/transport/azurestoragequeues.py index ed651d9c..b608feac 100644 --- a/kombu/transport/azurestoragequeues.py +++ b/kombu/transport/azurestoragequeues.py @@ -19,7 +19,8 @@ Connection string has the following format: .. code-block:: - azurestoragequeues://:STORAGE_ACCOUNT_ACCESS kEY@STORAGE_ACCOUNT_NAME + azurestoragequeues://STORAGE_ACCOUNT_ACCESS_KEY@STORAGE_ACCOUNT_URL + azurestoragequeues://SAS_TOKEN@STORAGE_ACCOUNT_URL Note that if the access key for the storage account contains a slash, it will have to be regenerated before it can be used in the connection URL. @@ -35,6 +36,8 @@ from __future__ import annotations import string from queue import Empty +from azure.core.exceptions import ResourceExistsError + from kombu.utils.encoding import safe_str from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property @@ -42,9 +45,9 @@ from kombu.utils.objects import cached_property from . import virtual try: - from azure.storage.queue import QueueService + from azure.storage.queue import QueueServiceClient except ImportError: # pragma: no cover - QueueService = None + QueueServiceClient = None # Azure storage queues allow only alphanumeric and dashes # so, replace everything with a dash @@ -63,14 +66,18 @@ class Channel(virtual.Channel): _noack_queues = set() def __init__(self, *args, **kwargs): - if QueueService is None: + if QueueServiceClient is None: raise ImportError('Azure Storage Queues transport requires the ' 'azure-storage-queue library') super().__init__(*args, **kwargs) - for queue_name in self.queue_service.list_queues(): - self._queue_name_cache[queue_name] = queue_name + self._credential, self._url = Transport.parse_uri( + self.conninfo.hostname + ) + + for queue in self.queue_service.list_queues(): + self._queue_name_cache[queue['name']] = queue def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: @@ -87,61 +94,64 @@ class Channel(virtual.Channel): """Ensure a queue exists.""" queue = self.entity_name(self.queue_name_prefix + queue) try: - return self._queue_name_cache[queue] + q = self._queue_service.get_queue_client( + queue=self._queue_name_cache[queue] + ) except KeyError: - self.queue_service.create_queue(queue, fail_on_exist=False) - q = self._queue_name_cache[queue] = queue - return q + try: + q = self.queue_service.create_queue(queue) + except ResourceExistsError: + q = self._queue_service.get_queue_client(queue=queue) + + self._queue_name_cache[queue] = q + return q def _delete(self, queue, *args, **kwargs): """Delete queue by name.""" queue_name = self.entity_name(queue) self._queue_name_cache.pop(queue_name, None) self.queue_service.delete_queue(queue_name) - super()._delete(queue_name) def _put(self, queue, message, **kwargs): """Put message onto queue.""" q = self._ensure_queue(queue) encoded_message = dumps(message) - self.queue_service.put_message(q, encoded_message) + q.send_message(encoded_message) def _get(self, queue, timeout=None): """Try to retrieve a single message off ``queue``.""" q = self._ensure_queue(queue) - messages = self.queue_service.get_messages(q, num_messages=1, - timeout=timeout) - if not messages: + messages = q.receive_messages(messages_per_page=1, timeout=timeout) + try: + message = next(messages) + except StopIteration: raise Empty() - message = messages[0] - raw_content = self.queue_service.decode_function(message.content) - content = loads(raw_content) + content = loads(message.content) - self.queue_service.delete_message(q, message.id, message.pop_receipt) + q.delete_message(message=message) return content def _size(self, queue): """Return the number of messages in a queue.""" q = self._ensure_queue(queue) - metadata = self.queue_service.get_queue_metadata(q) - return metadata.approximate_message_count + return q.get_queue_properties().approximate_message_count def _purge(self, queue): """Delete all current messages in a queue.""" q = self._ensure_queue(queue) - n = self._size(q) - self.queue_service.clear_messages(q) + n = self._size(q.queue_name) + q.clear_messages() return n @property def queue_service(self): if self._queue_service is None: - self._queue_service = QueueService( - account_name=self.conninfo.hostname, - account_key=self.conninfo.password) + self._queue_service = QueueServiceClient( + account_url=self._url, credential=self._credential + ) return self._queue_service @@ -165,3 +175,37 @@ class Transport(virtual.Transport): polling_interval = 1 default_port = None + can_parse_url = True + + @staticmethod + def parse_uri(uri: str) -> tuple[str, str]: + # URL like: + # azurestoragequeues://STORAGE_ACCOUNT_ACCESS_KEY@STORAGE_ACCOUNT_URL + # azurestoragequeues://SAS_TOKEN@STORAGE_ACCOUNT_URL + + # urllib parse does not work as the sas key could contain a slash + # e.g.: azurestoragequeues://some/key@someurl + + try: + # > 'some/key@url' + uri = uri.replace('azurestoragequeues://', '') + # > 'some/key', 'url' + credential, url = uri.rsplit('@', 1) + + # Validate parameters + assert all([credential, url]) + except Exception: + raise ValueError( + 'Need a URI like ' + 'azurestoragequeues://{SAS or access key}@{URL}' + ) + + return credential, url + + @classmethod + def as_uri(cls, uri: str, include_password=False, mask='**') -> str: + credential, url = cls.parse_uri(uri) + return 'azurestoragequeues://{}@{}'.format( + credential if include_password else mask, + url + ) diff --git a/requirements/extras/azurestoragequeues.txt b/requirements/extras/azurestoragequeues.txt index 2424ee7e..09e3ddc4 100644 --- a/requirements/extras/azurestoragequeues.txt +++ b/requirements/extras/azurestoragequeues.txt @@ -1 +1 @@ -azure-storage-queue +azure-storage-queue>=12.2.0 diff --git a/t/unit/transport/test_azurestoragequeues.py b/t/unit/transport/test_azurestoragequeues.py new file mode 100644 index 00000000..d5568cbb --- /dev/null +++ b/t/unit/transport/test_azurestoragequeues.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from kombu import Connection + +pytest.importorskip('azure.storage.queue') +from kombu.transport import azurestoragequeues # noqa + +URL_NOCREDS = 'azurestoragequeues://' +URL_CREDS = 'azurestoragequeues://sas/key%@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa + + +def test_queue_service_nocredentials(): + conn = Connection(URL_NOCREDS, transport=azurestoragequeues.Transport) + with pytest.raises( + ValueError, + match='Need a URI like azurestoragequeues://{SAS or access key}@{URL}' + ): + conn.channel() + + +def test_queue_service(): + # Test gettings queue service without credentials + conn = Connection(URL_CREDS, transport=azurestoragequeues.Transport) + with patch('kombu.transport.azurestoragequeues.QueueServiceClient'): + channel = conn.channel() + + # Check the SAS token "sas/key%" has been parsed from the url correctly + assert channel._credential == 'sas/key%' + assert channel._url == 'https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa