mirror of https://github.com/celery/kombu.git
Support for Azure Service Bus 7.0.0 (#1284)
* Started servicebus refactor * Cleaned up, handle service bus SAS token parsing
This commit is contained in:
parent
a37a05616f
commit
3d41ab1389
|
@ -4,9 +4,6 @@ Note that the Shared Access Policy used to connect to Azure Service Bus
|
|||
requires Manage, Send and Listen claims since the broker will create new
|
||||
queues and delete old queues as required.
|
||||
|
||||
Note that if the SAS key for the Service Bus account contains a slash, it will
|
||||
have to be regenerated before it can be used in the connection URL.
|
||||
|
||||
More information about Azure Service Bus:
|
||||
https://azure.microsoft.com/en-us/services/service-bus/
|
||||
|
||||
|
@ -31,31 +28,28 @@ Connection string has the following format:
|
|||
Transport Options
|
||||
=================
|
||||
|
||||
* ``visibility_timeout``
|
||||
* ``queue_name_prefix``
|
||||
* ``wait_time_seconds``
|
||||
* ``peek_lock``
|
||||
* ``queue_name_prefix`` - String prefix to prepend to queue names in a service bus namespace
|
||||
* ``wait_time_seconds`` - Number of seconds to wait to receive messages. Default ``5``
|
||||
* ``peek_lock_seconds`` - Number of seconds the message is visible for before it is requeued
|
||||
and sent to another consumer. Default ``60``
|
||||
"""
|
||||
|
||||
import string
|
||||
from queue import Empty
|
||||
from typing import Dict, Any, Optional, Union, Set
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import loads, dumps
|
||||
from kombu.utils.objects import cached_property
|
||||
|
||||
from . import virtual
|
||||
import azure.core.exceptions
|
||||
import azure.servicebus.exceptions
|
||||
import isodate
|
||||
from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusReceiver, ServiceBusSender, \
|
||||
ServiceBusReceiveMode
|
||||
from azure.servicebus.management import ServiceBusAdministrationClient
|
||||
|
||||
try:
|
||||
# azure-servicebus version <= 0.21.1
|
||||
from azure.servicebus import ServiceBusService, Message, Queue
|
||||
except ImportError:
|
||||
try:
|
||||
# azure-servicebus version >= 0.50.0
|
||||
from azure.servicebus.control_client import \
|
||||
ServiceBusService, Message, Queue
|
||||
except ImportError:
|
||||
ServiceBusService = Message = Queue = None
|
||||
from . import virtual
|
||||
|
||||
# dots are replaced by dash, all other punctuation replaced by underscore.
|
||||
CHARS_REPLACE_TABLE = {
|
||||
|
@ -63,94 +57,250 @@ CHARS_REPLACE_TABLE = {
|
|||
}
|
||||
|
||||
|
||||
class SendReceive:
|
||||
def __init__(self, receiver: Optional[ServiceBusReceiver] = None, sender: Optional[ServiceBusSender] = None):
|
||||
self.receiver = receiver # type: ServiceBusReceiver
|
||||
self.sender = sender # type: ServiceBusSender
|
||||
|
||||
def close(self) -> None:
|
||||
if self.receiver:
|
||||
self.receiver.close()
|
||||
self.receiver = None
|
||||
if self.sender:
|
||||
self.sender.close()
|
||||
self.sender = None
|
||||
|
||||
|
||||
class Channel(virtual.Channel):
|
||||
"""Azure Service Bus channel."""
|
||||
|
||||
default_visibility_timeout = 1800 # 30 minutes.
|
||||
default_wait_time_seconds = 5 # in seconds
|
||||
default_peek_lock = False
|
||||
default_peek_lock_seconds = 60 # in seconds (default 60, max 300)
|
||||
domain_format = 'kombu%(vhost)s'
|
||||
_queue_service = None
|
||||
_queue_cache = {}
|
||||
_queue_service = None # type: ServiceBusClient
|
||||
_queue_mgmt_service = None # type: ServiceBusAdministrationClient
|
||||
_queue_cache = {} # type: Dict[str, SendReceive]
|
||||
_noack_queues = set() # type: Set[str]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if ServiceBusService is None:
|
||||
raise ImportError('Azure Service Bus transport requires the '
|
||||
'azure-servicebus library')
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
for queue in self.queue_service.list_queues():
|
||||
self._queue_cache[queue] = queue
|
||||
self._namespace = None
|
||||
self._policy = None
|
||||
self._sas_key = None
|
||||
self._connection_string = None
|
||||
|
||||
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
|
||||
self._try_parse_connection_string()
|
||||
|
||||
self.qos.restore_at_shutdown = False
|
||||
|
||||
def _try_parse_connection_string(self) -> None:
|
||||
# URL like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
|
||||
# urllib parse does not work as the sas key could contain a slash
|
||||
# e.g. azureservicebus://rootpolicy:some/key@somenamespace
|
||||
uri = self.conninfo.hostname.replace('azureservicebus://', '') # > 'rootpolicy:some/key@somenamespace'
|
||||
policykeypair, self._namespace = uri.rsplit('@', 1) # > 'rootpolicy:some/key', 'somenamespace'
|
||||
self._policy, self._sas_key = policykeypair.split(':', 1) # > 'rootpolicy', 'some/key'
|
||||
|
||||
# Validate ASB connection string
|
||||
if not all([self._namespace, self._policy, self._sas_key]):
|
||||
raise ValueError('Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}')
|
||||
|
||||
# Convert
|
||||
endpoint = 'sb://' + self._namespace
|
||||
if not endpoint.endswith('.net'):
|
||||
endpoint += '.servicebus.windows.net'
|
||||
|
||||
conn_dict = {
|
||||
'Endpoint': endpoint,
|
||||
'SharedAccessKeyName': self._policy,
|
||||
'SharedAccessKey': self._sas_key,
|
||||
}
|
||||
self._connection_string = ';'.join([key + '=' + value for key, value in conn_dict.items()])
|
||||
|
||||
def basic_consume(self, queue, no_ack, *args, **kwargs):
|
||||
if no_ack:
|
||||
self._noack_queues.add(queue)
|
||||
return super().basic_consume(
|
||||
queue, no_ack, *args, **kwargs
|
||||
)
|
||||
|
||||
def basic_cancel(self, consumer_tag):
|
||||
if consumer_tag in self._consumers:
|
||||
queue = self._tag_to_queue[consumer_tag]
|
||||
self._noack_queues.discard(queue)
|
||||
return super().basic_cancel(consumer_tag)
|
||||
|
||||
def _add_queue_to_cache(self,
|
||||
name: str,
|
||||
receiver: Optional[ServiceBusReceiver] = None,
|
||||
sender: Optional[ServiceBusSender] = None) -> SendReceive:
|
||||
if name in self._queue_cache:
|
||||
obj = self._queue_cache[name]
|
||||
obj.sender = obj.sender or sender
|
||||
obj.receiver = obj.receiver or receiver
|
||||
else:
|
||||
obj = SendReceive(receiver, sender)
|
||||
self._queue_cache[name] = obj
|
||||
return obj
|
||||
|
||||
def _get_asb_sender(self, queue: str) -> SendReceive:
|
||||
queue_obj = self._queue_cache.get(queue, None)
|
||||
if queue_obj is None or queue_obj.sender is None:
|
||||
sender = self.queue_service.get_queue_sender(queue)
|
||||
queue_obj = self._add_queue_to_cache(queue, sender=sender)
|
||||
return queue_obj
|
||||
|
||||
def _get_asb_receiver(self, queue: str,
|
||||
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
|
||||
queue_cache_key: Optional[str] = None) -> SendReceive:
|
||||
cache_key = queue_cache_key or queue
|
||||
queue_obj = self._queue_cache.get(cache_key, None)
|
||||
if queue_obj is None or queue_obj.receiver is None:
|
||||
receiver = self.queue_service.get_queue_receiver(queue_name=queue, receive_mode=recv_mode)
|
||||
queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
|
||||
return queue_obj
|
||||
|
||||
def entity_name(self, name: str, table: Optional[Dict[int, int]] = None) -> str:
|
||||
"""Format AMQP queue name into a valid ServiceBus queue name."""
|
||||
return str(safe_str(name)).translate(table)
|
||||
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
|
||||
|
||||
def _new_queue(self, queue, **kwargs):
|
||||
def _restore(self, message: virtual.base.Message) -> None:
|
||||
# Not be needed as ASB handles unacked messages
|
||||
# Remove 'azure_message' as its not JSON serializable
|
||||
# message.delivery_info.pop('azure_message', None)
|
||||
# super()._restore(message)
|
||||
pass
|
||||
|
||||
def _new_queue(self, queue: str, **kwargs) -> SendReceive:
|
||||
"""Ensure a queue exists in ServiceBus."""
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
try:
|
||||
return self._queue_cache[queue]
|
||||
except KeyError:
|
||||
self.queue_service.create_queue(queue, fail_on_exist=False)
|
||||
q = self._queue_cache[queue] = self.queue_service.get_queue(queue)
|
||||
return q
|
||||
# Converts seconds into ISO8601 duration format ie 66seconds = P1M6S
|
||||
lock_duration = isodate.duration_isoformat(isodate.Duration(seconds=self.peek_lock_seconds))
|
||||
try:
|
||||
self.queue_mgmt_service.create_queue(queue_name=queue, lock_duration=lock_duration)
|
||||
except azure.core.exceptions.ResourceExistsError:
|
||||
pass
|
||||
return self._add_queue_to_cache(queue)
|
||||
|
||||
def _delete(self, queue, *args, **kwargs):
|
||||
def _delete(self, queue: str, *args, **kwargs) -> None:
|
||||
"""Delete queue by name."""
|
||||
queue_name = self.entity_name(queue)
|
||||
self._queue_cache.pop(queue_name, None)
|
||||
self.queue_service.delete_queue(queue_name)
|
||||
super()._delete(queue_name)
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
def _put(self, queue, message, **kwargs):
|
||||
self._queue_mgmt_service.delete_queue(queue)
|
||||
send_receive_obj = self._queue_cache.pop(queue, None)
|
||||
if send_receive_obj:
|
||||
send_receive_obj.close()
|
||||
|
||||
def _put(self, queue: str, message, **kwargs) -> None:
|
||||
"""Put message onto queue."""
|
||||
msg = Message(dumps(message))
|
||||
self.queue_service.send_queue_message(self.entity_name(queue), msg)
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
msg = ServiceBusMessage(dumps(message))
|
||||
|
||||
def _get(self, queue, timeout=None):
|
||||
queue_obj = self._get_asb_sender(queue)
|
||||
queue_obj.sender.send_messages(msg)
|
||||
|
||||
def _get(self, queue: str, timeout: Optional[Union[float, int]] = None) -> Dict[str, Any]:
|
||||
"""Try to retrieve a single message off ``queue``."""
|
||||
message = self.queue_service.receive_queue_message(
|
||||
self.entity_name(queue),
|
||||
timeout=timeout or self.wait_time_seconds,
|
||||
peek_lock=self.peek_lock
|
||||
)
|
||||
# If we're not ack'ing for this queue, just change receive_mode
|
||||
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE if queue in self._noack_queues else \
|
||||
ServiceBusReceiveMode.PEEK_LOCK
|
||||
|
||||
if message.body is None:
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
queue_obj = self._get_asb_receiver(queue, recv_mode)
|
||||
messages = queue_obj.receiver.receive_messages(max_message_count=1,
|
||||
max_wait_time=timeout or self.wait_time_seconds)
|
||||
|
||||
if not messages:
|
||||
raise Empty()
|
||||
|
||||
return loads(bytes_to_str(message.body))
|
||||
# message.body is either byte or generator[bytes]
|
||||
message = messages[0]
|
||||
if not isinstance(message.body, bytes):
|
||||
body = b''.join(message.body)
|
||||
else:
|
||||
body = message.body
|
||||
|
||||
def _size(self, queue):
|
||||
msg = loads(bytes_to_str(body))
|
||||
msg['properties']['delivery_info']['azure_message'] = message
|
||||
|
||||
return msg
|
||||
|
||||
def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
|
||||
delivery_info = self.qos.get(delivery_tag).delivery_info
|
||||
|
||||
if delivery_info['exchange'] in self._noack_queues:
|
||||
return super().basic_ack(delivery_tag)
|
||||
|
||||
queue = self.entity_name(self.queue_name_prefix + delivery_info['exchange'])
|
||||
queue_obj = self._get_asb_receiver(queue) # recv_mode is PEEK_LOCK when ack'ing messages
|
||||
|
||||
try:
|
||||
queue_obj.receiver.complete_message(delivery_info['azure_message'])
|
||||
except azure.servicebus.exceptions.MessageAlreadySettled:
|
||||
super().basic_ack(delivery_tag)
|
||||
except Exception:
|
||||
super().basic_reject(delivery_tag)
|
||||
else:
|
||||
super().basic_ack(delivery_tag)
|
||||
|
||||
def _size(self, queue: str) -> int:
|
||||
"""Return the number of messages in a queue."""
|
||||
return self._new_queue(queue).message_count
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
|
||||
|
||||
return props.total_message_count
|
||||
|
||||
def _purge(self, queue):
|
||||
"""Delete all current messages in a queue."""
|
||||
# Azure doesn't provide a purge api yet
|
||||
n = 0
|
||||
max_purge_count = 10
|
||||
queue = self.entity_name(self.queue_name_prefix + queue)
|
||||
|
||||
# By default all the receivers will be in PEEK_LOCK receive mode
|
||||
queue_obj = self._queue_cache.get(queue, None)
|
||||
if queue not in self._noack_queues or queue_obj is None or queue_obj.receiver is None:
|
||||
queue_obj = self._get_asb_receiver(queue, ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue)
|
||||
|
||||
while True:
|
||||
message = self.queue_service.read_delete_queue_message(
|
||||
self.entity_name(queue), timeout=0.1)
|
||||
messages = queue_obj.receiver.receive_messages(max_message_count=max_purge_count,
|
||||
max_wait_time=0.2)
|
||||
n += len(messages)
|
||||
|
||||
if not message.body:
|
||||
if len(messages) < max_purge_count:
|
||||
break
|
||||
else:
|
||||
n += 1
|
||||
|
||||
return n
|
||||
|
||||
@property
|
||||
def queue_service(self):
|
||||
if self._queue_service is None:
|
||||
self._queue_service = ServiceBusService(
|
||||
service_namespace=self.conninfo.hostname,
|
||||
shared_access_key_name=self.conninfo.userid,
|
||||
shared_access_key_value=self.conninfo.password)
|
||||
def close(self) -> None:
|
||||
# receivers and senders spawn threads so clean them up
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
for queue_obj in self._queue_cache.values():
|
||||
queue_obj.close()
|
||||
self._queue_cache.clear()
|
||||
|
||||
if self.connection is not None:
|
||||
self.connection.close_channel(self)
|
||||
|
||||
@property
|
||||
def queue_service(self) -> ServiceBusClient:
|
||||
if self._queue_service is None:
|
||||
self._queue_service = ServiceBusClient.from_connection_string(self._connection_string)
|
||||
return self._queue_service
|
||||
|
||||
@property
|
||||
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
||||
if self._queue_mgmt_service is None:
|
||||
self._queue_mgmt_service = ServiceBusAdministrationClient.from_connection_string(self._connection_string)
|
||||
return self._queue_mgmt_service
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
return self.connection.client
|
||||
|
@ -160,23 +310,19 @@ class Channel(virtual.Channel):
|
|||
return self.connection.client.transport_options
|
||||
|
||||
@cached_property
|
||||
def visibility_timeout(self):
|
||||
return (self.transport_options.get('visibility_timeout') or
|
||||
self.default_visibility_timeout)
|
||||
|
||||
@cached_property
|
||||
def queue_name_prefix(self):
|
||||
def queue_name_prefix(self) -> str:
|
||||
return self.transport_options.get('queue_name_prefix', '')
|
||||
|
||||
@cached_property
|
||||
def wait_time_seconds(self):
|
||||
def wait_time_seconds(self) -> int:
|
||||
return self.transport_options.get('wait_time_seconds',
|
||||
self.default_wait_time_seconds)
|
||||
|
||||
@cached_property
|
||||
def peek_lock(self):
|
||||
return self.transport_options.get('peek_lock',
|
||||
self.default_peek_lock)
|
||||
def peek_lock_seconds(self) -> int:
|
||||
return min(self.transport_options.get('peek_lock_seconds',
|
||||
self.default_peek_lock_seconds),
|
||||
300) # Limit upper bounds to 300
|
||||
|
||||
|
||||
class Transport(virtual.Transport):
|
||||
|
@ -186,3 +332,4 @@ class Transport(virtual.Transport):
|
|||
|
||||
polling_interval = 1
|
||||
default_port = None
|
||||
can_parse_url = True
|
||||
|
|
|
@ -1 +1 @@
|
|||
azure-servicebus>=0.21.1
|
||||
azure-servicebus>=7.0.0
|
||||
|
|
|
@ -1,244 +1,277 @@
|
|||
import json
|
||||
import pytest
|
||||
import base64
|
||||
import random
|
||||
from queue import Empty
|
||||
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, MagicMock
|
||||
from collections import namedtuple
|
||||
from kombu import messaging
|
||||
from kombu import Connection, Exchange, Queue
|
||||
|
||||
from kombu.transport import azureservicebus
|
||||
|
||||
import azure.servicebus.exceptions
|
||||
import azure.core.exceptions
|
||||
pytest.importorskip('azure.servicebus')
|
||||
|
||||
try:
|
||||
# azure-servicebus version >= 0.50.0
|
||||
from azure.servicebus.control_client import Message, ServiceBusService
|
||||
except ImportError:
|
||||
try:
|
||||
# azure-servicebus version <= 0.21.1
|
||||
from azure.servicebus import Message, ServiceBusService
|
||||
except ImportError:
|
||||
ServiceBusService = Message = None
|
||||
from azure.servicebus import ServiceBusMessage, ServiceBusReceiveMode
|
||||
|
||||
|
||||
class QueueMock:
|
||||
""" Hold information about a queue. """
|
||||
class ASBQueue:
|
||||
def __init__(self, kwargs):
|
||||
self.options = kwargs
|
||||
self.items = []
|
||||
self.waiting_ack = []
|
||||
self.send_calls = []
|
||||
self.recv_calls = []
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.messages = []
|
||||
self.message_count = len(self.messages)
|
||||
def get_receiver(self, kwargs):
|
||||
receive_mode = kwargs.get('receive_mode', ServiceBusReceiveMode.PEEK_LOCK)
|
||||
|
||||
def __repr__(self):
|
||||
return 'QueueMock: {} messages'.format(len(self.messages))
|
||||
class Receiver:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def receive_messages(_self, **kwargs2):
|
||||
max_message_count = kwargs2.get('max_message_count', 1)
|
||||
result = []
|
||||
if self.items:
|
||||
while self.items or len(result) > max_message_count:
|
||||
item = self.items.pop(0)
|
||||
if receive_mode is ServiceBusReceiveMode.PEEK_LOCK:
|
||||
self.waiting_ack.append(item)
|
||||
result.append(item)
|
||||
|
||||
self.recv_calls.append({
|
||||
'receiver_options': kwargs,
|
||||
'receive_messages_options': kwargs2,
|
||||
'messages': result
|
||||
})
|
||||
return result
|
||||
return Receiver()
|
||||
|
||||
def get_sender(self):
|
||||
class Sender:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def send_messages(_self, msg):
|
||||
self.send_calls.append(msg)
|
||||
self.items.append(msg)
|
||||
return Sender()
|
||||
|
||||
|
||||
def _create_mock_connection(url='', **kwargs):
|
||||
|
||||
class _Channel(azureservicebus.Channel):
|
||||
# reset _fanout_queues for each instance
|
||||
queues = []
|
||||
_queue_service = None
|
||||
|
||||
def list_queues(self):
|
||||
return self.queues
|
||||
|
||||
@property
|
||||
def queue_service(self):
|
||||
if self._queue_service is None:
|
||||
self._queue_service = AzureServiceBusClientMock()
|
||||
return self._queue_service
|
||||
|
||||
class Transport(azureservicebus.Transport):
|
||||
Channel = _Channel
|
||||
|
||||
return Connection(url, transport=Transport, **kwargs)
|
||||
|
||||
|
||||
class AzureServiceBusClientMock:
|
||||
|
||||
class ASBMock:
|
||||
def __init__(self):
|
||||
"""
|
||||
Imitate the ServiceBus Client.
|
||||
"""
|
||||
# queues doesn't exist on the real client, here for testing.
|
||||
self.queues = []
|
||||
self._queue_cache = {}
|
||||
self.queues.append(self.create_queue(queue_name='unittest_queue'))
|
||||
self.queues = {}
|
||||
|
||||
def create_queue(self, queue_name, queue=None, fail_on_exist=False):
|
||||
queue = QueueMock(name=queue_name)
|
||||
self.queues.append(queue)
|
||||
self._queue_cache[queue_name] = queue
|
||||
return queue
|
||||
def get_queue_receiver(self, queue_name, **kwargs):
|
||||
return self.queues[queue_name].get_receiver(kwargs)
|
||||
|
||||
def get_queue(self, queue_name=None):
|
||||
for queue in self.queues:
|
||||
if queue.name == queue_name:
|
||||
return queue
|
||||
|
||||
def list_queues(self):
|
||||
return self.queues
|
||||
|
||||
def send_queue_message(self, queue_name=None, message=None):
|
||||
queue = self.get_queue(queue_name)
|
||||
queue.messages.append(message)
|
||||
|
||||
def receive_queue_message(self, queue_name, peek_lock=True, timeout=60):
|
||||
queue = self.get_queue(queue_name)
|
||||
if queue and len(queue.messages):
|
||||
return queue.messages.pop(0)
|
||||
return Message()
|
||||
|
||||
def read_delete_queue_message(self, queue_name, timeout='60'):
|
||||
return self.receive_queue_message(queue_name, timeout=timeout)
|
||||
|
||||
def delete_queue(self, queue_name=None):
|
||||
queue = self.get_queue(queue_name)
|
||||
if queue:
|
||||
del queue
|
||||
def get_queue_sender(self, queue_name):
|
||||
return self.queues[queue_name].get_sender()
|
||||
|
||||
|
||||
class test_Channel:
|
||||
class ASBMgmtMock:
|
||||
def __init__(self, queues):
|
||||
self.queues = queues
|
||||
|
||||
def handleMessageCallback(self, message):
|
||||
self.callback_message = message
|
||||
def create_queue(self, queue_name, **kwargs):
|
||||
if queue_name in self.queues:
|
||||
raise azure.core.exceptions.ResourceExistsError()
|
||||
self.queues[queue_name] = ASBQueue(kwargs)
|
||||
|
||||
def setup(self):
|
||||
self.url = 'azureservicebus://'
|
||||
self.queue_name = 'unittest_queue'
|
||||
def delete_queue(self, queue_name):
|
||||
self.queues.pop(queue_name, None)
|
||||
|
||||
self.exchange = Exchange('test_servicebus', type='direct')
|
||||
self.queue = Queue(self.queue_name, self.exchange, self.queue_name)
|
||||
self.connection = _create_mock_connection(self.url)
|
||||
self.channel = self.connection.default_channel
|
||||
self.queue(self.channel).declare()
|
||||
def get_queue_runtime_properties(self, queue_name):
|
||||
count = len(self.queues[queue_name].items)
|
||||
mock = MagicMock()
|
||||
mock.total_message_count = count
|
||||
return mock
|
||||
|
||||
self.producer = messaging.Producer(self.channel,
|
||||
self.exchange,
|
||||
routing_key=self.queue_name)
|
||||
|
||||
self.channel.basic_consume(self.queue_name,
|
||||
no_ack=False,
|
||||
callback=self.handleMessageCallback,
|
||||
consumer_tag='unittest')
|
||||
URL_NOCREDS = 'azureservicebus://'
|
||||
URL_CREDS = 'azureservicebus://policyname:ke/y@hostname'
|
||||
|
||||
def teardown(self):
|
||||
# Removes QoS reserved messages so we don't restore msgs on shutdown.
|
||||
try:
|
||||
qos = self.channel._qos
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if qos:
|
||||
qos._dirty.clear()
|
||||
qos._delivered.clear()
|
||||
|
||||
def test_queue_service(self):
|
||||
# Test gettings queue service without credentials
|
||||
conn = Connection(self.url, transport=azureservicebus.Transport)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
conn.channel()
|
||||
assert exc == 'You need to provide servicebus namespace'
|
||||
def test_queue_service_nocredentials():
|
||||
conn = Connection(URL_NOCREDS, transport=azureservicebus.Transport)
|
||||
with pytest.raises(ValueError) as exc:
|
||||
conn.channel()
|
||||
assert exc == 'Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}'
|
||||
|
||||
# Test getting queue service when queue_service is not setted
|
||||
with patch('kombu.transport.azureservicebus.ServiceBusService') as m:
|
||||
channel = conn.channel()
|
||||
|
||||
# Remove queue service to get from service bus again
|
||||
channel._queue_service = None
|
||||
channel.queue_service
|
||||
def test_queue_service():
|
||||
# Test gettings queue service without credentials
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
with patch('kombu.transport.azureservicebus.ServiceBusClient') as m:
|
||||
channel = conn.channel()
|
||||
|
||||
assert m.call_count == 2
|
||||
# Check the SAS token "ke/y" has been parsed from the url correctly
|
||||
assert channel._sas_key == 'ke/y'
|
||||
|
||||
# Calling queue_service again needs to reuse ServiceBus instance
|
||||
channel.queue_service
|
||||
assert m.call_count == 2
|
||||
m.from_connection_string.return_value = 'test'
|
||||
|
||||
def test_conninfo(self):
|
||||
conninfo = self.channel.conninfo
|
||||
assert conninfo is self.connection
|
||||
# Remove queue service to get from service bus again
|
||||
channel._queue_service = None
|
||||
assert channel.queue_service == 'test'
|
||||
assert m.from_connection_string.call_count == 1
|
||||
|
||||
def test_transport_type(self):
|
||||
transport_options = self.channel.transport_options
|
||||
assert transport_options == {}
|
||||
# Ensure that queue_service is cached
|
||||
assert channel.queue_service == 'test'
|
||||
assert m.from_connection_string.call_count == 1
|
||||
|
||||
def test_visibility_timeout(self):
|
||||
# Test getting default visibility timeout
|
||||
assert (
|
||||
self.channel.visibility_timeout ==
|
||||
azureservicebus.Channel.default_visibility_timeout
|
||||
)
|
||||
|
||||
# Test getting value setted in transport options
|
||||
del self.channel.visibility_timeout
|
||||
self.channel.transport_options['visibility_timeout'] = 10
|
||||
assert self.channel.visibility_timeout == 10
|
||||
def test_conninfo():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
channel = conn.channel()
|
||||
assert channel.conninfo is conn
|
||||
|
||||
def test_wait_timeout_seconds(self):
|
||||
# Test getting default wait timeout seconds
|
||||
assert (
|
||||
self.channel.wait_time_seconds ==
|
||||
azureservicebus.Channel.default_wait_time_seconds
|
||||
)
|
||||
|
||||
# Test getting value setted in transport options
|
||||
del self.channel.wait_time_seconds
|
||||
self.channel.transport_options['wait_time_seconds'] = 10
|
||||
assert self.channel.wait_time_seconds == 10
|
||||
def test_transport_type():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
channel = conn.channel()
|
||||
assert not channel.transport_options
|
||||
|
||||
def test_peek_lock(self):
|
||||
# Test getting default peek lock
|
||||
assert (
|
||||
self.channel.peek_lock ==
|
||||
azureservicebus.Channel.default_peek_lock
|
||||
)
|
||||
|
||||
# Test getting value setted in transport options
|
||||
del self.channel.peek_lock
|
||||
self.channel.transport_options['peek_lock'] = True
|
||||
assert self.channel.peek_lock is True
|
||||
def test_default_wait_timeout_seconds():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
channel = conn.channel()
|
||||
|
||||
def test_get_from_azure(self):
|
||||
# Test getting a single message
|
||||
message = 'my test message'
|
||||
self.producer.publish(message)
|
||||
result = self.channel._get(self.queue_name)
|
||||
assert 'body' in result.keys()
|
||||
assert channel.wait_time_seconds == azureservicebus.Channel.default_wait_time_seconds
|
||||
|
||||
# Test getting multiple messages
|
||||
for i in range(3):
|
||||
message = f'message: {i}'
|
||||
self.producer.publish(message)
|
||||
|
||||
queue_service = self.channel.queue_service
|
||||
assert len(queue_service.get_queue(self.queue_name).messages) == 3
|
||||
def test_custom_wait_timeout_seconds():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport, transport_options={'wait_time_seconds': 10})
|
||||
channel = conn.channel()
|
||||
|
||||
for i in range(3):
|
||||
result = self.channel._get(self.queue_name)
|
||||
assert channel.wait_time_seconds == 10
|
||||
|
||||
assert len(queue_service.get_queue(self.queue_name).messages) == 0
|
||||
|
||||
def test_get_with_empty_list(self):
|
||||
with pytest.raises(Empty):
|
||||
self.channel._get(self.queue_name)
|
||||
def test_default_peek_lock_seconds():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
channel = conn.channel()
|
||||
|
||||
def test_put_and_get(self):
|
||||
message = 'my test message'
|
||||
self.producer.publish(message)
|
||||
results = self.queue(self.channel).get().payload
|
||||
assert message == results
|
||||
assert channel.peek_lock_seconds == azureservicebus.Channel.default_peek_lock_seconds
|
||||
|
||||
def test_delete_queue(self):
|
||||
# Test deleting queue without message
|
||||
queue_name = 'new_unittest_queue'
|
||||
self.channel._new_queue(queue_name)
|
||||
|
||||
assert queue_name in self.channel._queue_cache
|
||||
self.channel._delete(queue_name)
|
||||
assert queue_name not in self.channel._queue_cache
|
||||
def test_custom_peek_lock_seconds():
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
|
||||
transport_options={'peek_lock_seconds': 65})
|
||||
channel = conn.channel()
|
||||
|
||||
# Test deleting queue with message
|
||||
message = 'my test message'
|
||||
self.producer.publish(message)
|
||||
self.channel._delete(self.queue_name)
|
||||
assert queue_name not in self.channel._queue_cache
|
||||
assert channel.peek_lock_seconds == 65
|
||||
|
||||
|
||||
def test_invalid_peek_lock_seconds():
|
||||
# Max is 300
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
|
||||
transport_options={'peek_lock_seconds': 900})
|
||||
channel = conn.channel()
|
||||
|
||||
assert channel.peek_lock_seconds == 300
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_queue():
|
||||
return 'azureservicebus_queue_{0}'.format(random.randint(1000,9999))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_asb():
|
||||
return ASBMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_asb_management(mock_asb):
|
||||
return ASBMgmtMock(queues=mock_asb.queues)
|
||||
|
||||
|
||||
MockQueue = namedtuple('MockQueue', ['queue_name', 'asb', 'asb_mgmt', 'conn', 'channel', 'producer', 'queue'])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue:
|
||||
exchange = Exchange('test_servicebus', type='direct')
|
||||
queue = Queue(random_queue, exchange, random_queue)
|
||||
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
|
||||
channel = conn.channel()
|
||||
channel._queue_service = mock_asb
|
||||
channel._queue_mgmt_service = mock_asb_management
|
||||
|
||||
queue(channel).declare()
|
||||
producer = messaging.Producer(channel, exchange, routing_key=random_queue)
|
||||
|
||||
return MockQueue(
|
||||
random_queue,
|
||||
mock_asb,
|
||||
mock_asb_management,
|
||||
conn,
|
||||
channel,
|
||||
producer,
|
||||
queue
|
||||
)
|
||||
|
||||
|
||||
def test_basic_put_get(mock_queue: MockQueue):
|
||||
text_message = "test message"
|
||||
|
||||
# This ends up hitting channel._put
|
||||
mock_queue.producer.publish(text_message)
|
||||
|
||||
assert len(mock_queue.asb.queues[mock_queue.queue_name].items) == 1
|
||||
azure_msg = mock_queue.asb.queues[mock_queue.queue_name].items[0]
|
||||
assert isinstance(azure_msg, ServiceBusMessage)
|
||||
|
||||
message = mock_queue.channel._get(mock_queue.queue_name)
|
||||
azure_msg_decoded = json.loads(str(azure_msg))
|
||||
|
||||
assert message['body'] == azure_msg_decoded['body']
|
||||
|
||||
# Check the message has been annotated with the azure message object
|
||||
# which is used to ack later
|
||||
assert message['properties']['delivery_info']['azure_message'] is azure_msg
|
||||
|
||||
assert base64.b64decode(message['body']).decode() == text_message
|
||||
|
||||
# Ack is on by default, check an ack is waiting
|
||||
assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 1
|
||||
|
||||
|
||||
def test_empty_queue_get(mock_queue: MockQueue):
|
||||
with pytest.raises(Empty):
|
||||
mock_queue.channel._get(mock_queue.queue_name)
|
||||
|
||||
|
||||
def test_delete_empty_queue(mock_queue: MockQueue):
|
||||
chan = mock_queue.channel
|
||||
queue_name = 'random_queue_{0}'.format(random.randint(1000, 9999))
|
||||
|
||||
chan._new_queue(queue_name)
|
||||
assert queue_name in chan._queue_cache
|
||||
chan._delete(queue_name)
|
||||
assert queue_name not in chan._queue_cache
|
||||
|
||||
|
||||
def test_delete_populated_queue(mock_queue: MockQueue):
|
||||
mock_queue.producer.publish('test1234')
|
||||
|
||||
mock_queue.channel._delete(mock_queue.queue_name)
|
||||
assert mock_queue.queue_name not in mock_queue.channel._queue_cache
|
||||
|
||||
|
||||
def test_purge(mock_queue: MockQueue):
|
||||
mock_queue.producer.publish('test1234')
|
||||
mock_queue.producer.publish('test1234')
|
||||
mock_queue.producer.publish('test1234')
|
||||
mock_queue.producer.publish('test1234')
|
||||
|
||||
size = mock_queue.channel._size(mock_queue.queue_name)
|
||||
assert size == 4
|
||||
|
||||
assert mock_queue.channel._purge(mock_queue.queue_name) == 4
|
||||
|
||||
size = mock_queue.channel._size(mock_queue.queue_name)
|
||||
assert size == 0
|
||||
assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 0
|
||||
|
|
Loading…
Reference in New Issue