Support for Azure Service Bus 7.0.0 (#1284)

* Started servicebus refactor

* Cleaned up, handle service bus SAS token parsing
This commit is contained in:
Terry Cain 2021-01-04 13:14:39 +00:00 committed by GitHub
parent a37a05616f
commit 3d41ab1389
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 450 additions and 270 deletions

View File

@ -4,9 +4,6 @@ Note that the Shared Access Policy used to connect to Azure Service Bus
requires Manage, Send and Listen claims since the broker will create new
queues and delete old queues as required.
Note that if the SAS key for the Service Bus account contains a slash, it will
have to be regenerated before it can be used in the connection URL.
More information about Azure Service Bus:
https://azure.microsoft.com/en-us/services/service-bus/
@ -31,31 +28,28 @@ Connection string has the following format:
Transport Options
=================
* ``visibility_timeout``
* ``queue_name_prefix``
* ``wait_time_seconds``
* ``peek_lock``
* ``queue_name_prefix`` - String prefix to prepend to queue names in a service bus namespace
* ``wait_time_seconds`` - Number of seconds to wait to receive messages. Default ``5``
* ``peek_lock_seconds`` - Number of seconds the message is visible for before it is requeued
and sent to another consumer. Default ``60``
"""
import string
from queue import Empty
from typing import Dict, Any, Optional, Union, Set
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property
from . import virtual
import azure.core.exceptions
import azure.servicebus.exceptions
import isodate
from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusReceiver, ServiceBusSender, \
ServiceBusReceiveMode
from azure.servicebus.management import ServiceBusAdministrationClient
try:
# azure-servicebus version <= 0.21.1
from azure.servicebus import ServiceBusService, Message, Queue
except ImportError:
try:
# azure-servicebus version >= 0.50.0
from azure.servicebus.control_client import \
ServiceBusService, Message, Queue
except ImportError:
ServiceBusService = Message = Queue = None
from . import virtual
# dots are replaced by dash, all other punctuation replaced by underscore.
CHARS_REPLACE_TABLE = {
@ -63,94 +57,250 @@ CHARS_REPLACE_TABLE = {
}
class SendReceive:
def __init__(self, receiver: Optional[ServiceBusReceiver] = None, sender: Optional[ServiceBusSender] = None):
self.receiver = receiver # type: ServiceBusReceiver
self.sender = sender # type: ServiceBusSender
def close(self) -> None:
if self.receiver:
self.receiver.close()
self.receiver = None
if self.sender:
self.sender.close()
self.sender = None
class Channel(virtual.Channel):
"""Azure Service Bus channel."""
default_visibility_timeout = 1800 # 30 minutes.
default_wait_time_seconds = 5 # in seconds
default_peek_lock = False
default_peek_lock_seconds = 60 # in seconds (default 60, max 300)
domain_format = 'kombu%(vhost)s'
_queue_service = None
_queue_cache = {}
_queue_service = None # type: ServiceBusClient
_queue_mgmt_service = None # type: ServiceBusAdministrationClient
_queue_cache = {} # type: Dict[str, SendReceive]
_noack_queues = set() # type: Set[str]
def __init__(self, *args, **kwargs):
if ServiceBusService is None:
raise ImportError('Azure Service Bus transport requires the '
'azure-servicebus library')
super().__init__(*args, **kwargs)
for queue in self.queue_service.list_queues():
self._queue_cache[queue] = queue
self._namespace = None
self._policy = None
self._sas_key = None
self._connection_string = None
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
self._try_parse_connection_string()
self.qos.restore_at_shutdown = False
def _try_parse_connection_string(self) -> None:
# URL like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
# urllib parse does not work as the sas key could contain a slash
# e.g. azureservicebus://rootpolicy:some/key@somenamespace
uri = self.conninfo.hostname.replace('azureservicebus://', '') # > 'rootpolicy:some/key@somenamespace'
policykeypair, self._namespace = uri.rsplit('@', 1) # > 'rootpolicy:some/key', 'somenamespace'
self._policy, self._sas_key = policykeypair.split(':', 1) # > 'rootpolicy', 'some/key'
# Validate ASB connection string
if not all([self._namespace, self._policy, self._sas_key]):
raise ValueError('Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}')
# Convert
endpoint = 'sb://' + self._namespace
if not endpoint.endswith('.net'):
endpoint += '.servicebus.windows.net'
conn_dict = {
'Endpoint': endpoint,
'SharedAccessKeyName': self._policy,
'SharedAccessKey': self._sas_key,
}
self._connection_string = ';'.join([key + '=' + value for key, value in conn_dict.items()])
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
return super().basic_consume(
queue, no_ack, *args, **kwargs
)
def basic_cancel(self, consumer_tag):
if consumer_tag in self._consumers:
queue = self._tag_to_queue[consumer_tag]
self._noack_queues.discard(queue)
return super().basic_cancel(consumer_tag)
def _add_queue_to_cache(self,
name: str,
receiver: Optional[ServiceBusReceiver] = None,
sender: Optional[ServiceBusSender] = None) -> SendReceive:
if name in self._queue_cache:
obj = self._queue_cache[name]
obj.sender = obj.sender or sender
obj.receiver = obj.receiver or receiver
else:
obj = SendReceive(receiver, sender)
self._queue_cache[name] = obj
return obj
def _get_asb_sender(self, queue: str) -> SendReceive:
queue_obj = self._queue_cache.get(queue, None)
if queue_obj is None or queue_obj.sender is None:
sender = self.queue_service.get_queue_sender(queue)
queue_obj = self._add_queue_to_cache(queue, sender=sender)
return queue_obj
def _get_asb_receiver(self, queue: str,
recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
queue_cache_key: Optional[str] = None) -> SendReceive:
cache_key = queue_cache_key or queue
queue_obj = self._queue_cache.get(cache_key, None)
if queue_obj is None or queue_obj.receiver is None:
receiver = self.queue_service.get_queue_receiver(queue_name=queue, receive_mode=recv_mode)
queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
return queue_obj
def entity_name(self, name: str, table: Optional[Dict[int, int]] = None) -> str:
"""Format AMQP queue name into a valid ServiceBus queue name."""
return str(safe_str(name)).translate(table)
return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
def _new_queue(self, queue, **kwargs):
def _restore(self, message: virtual.base.Message) -> None:
# Not be needed as ASB handles unacked messages
# Remove 'azure_message' as its not JSON serializable
# message.delivery_info.pop('azure_message', None)
# super()._restore(message)
pass
def _new_queue(self, queue: str, **kwargs) -> SendReceive:
"""Ensure a queue exists in ServiceBus."""
queue = self.entity_name(self.queue_name_prefix + queue)
try:
return self._queue_cache[queue]
except KeyError:
self.queue_service.create_queue(queue, fail_on_exist=False)
q = self._queue_cache[queue] = self.queue_service.get_queue(queue)
return q
# Converts seconds into ISO8601 duration format ie 66seconds = P1M6S
lock_duration = isodate.duration_isoformat(isodate.Duration(seconds=self.peek_lock_seconds))
try:
self.queue_mgmt_service.create_queue(queue_name=queue, lock_duration=lock_duration)
except azure.core.exceptions.ResourceExistsError:
pass
return self._add_queue_to_cache(queue)
def _delete(self, queue, *args, **kwargs):
def _delete(self, queue: str, *args, **kwargs) -> None:
"""Delete queue by name."""
queue_name = self.entity_name(queue)
self._queue_cache.pop(queue_name, None)
self.queue_service.delete_queue(queue_name)
super()._delete(queue_name)
queue = self.entity_name(self.queue_name_prefix + queue)
def _put(self, queue, message, **kwargs):
self._queue_mgmt_service.delete_queue(queue)
send_receive_obj = self._queue_cache.pop(queue, None)
if send_receive_obj:
send_receive_obj.close()
def _put(self, queue: str, message, **kwargs) -> None:
"""Put message onto queue."""
msg = Message(dumps(message))
self.queue_service.send_queue_message(self.entity_name(queue), msg)
queue = self.entity_name(self.queue_name_prefix + queue)
msg = ServiceBusMessage(dumps(message))
def _get(self, queue, timeout=None):
queue_obj = self._get_asb_sender(queue)
queue_obj.sender.send_messages(msg)
def _get(self, queue: str, timeout: Optional[Union[float, int]] = None) -> Dict[str, Any]:
"""Try to retrieve a single message off ``queue``."""
message = self.queue_service.receive_queue_message(
self.entity_name(queue),
timeout=timeout or self.wait_time_seconds,
peek_lock=self.peek_lock
)
# If we're not ack'ing for this queue, just change receive_mode
recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE if queue in self._noack_queues else \
ServiceBusReceiveMode.PEEK_LOCK
if message.body is None:
queue = self.entity_name(self.queue_name_prefix + queue)
queue_obj = self._get_asb_receiver(queue, recv_mode)
messages = queue_obj.receiver.receive_messages(max_message_count=1,
max_wait_time=timeout or self.wait_time_seconds)
if not messages:
raise Empty()
return loads(bytes_to_str(message.body))
# message.body is either byte or generator[bytes]
message = messages[0]
if not isinstance(message.body, bytes):
body = b''.join(message.body)
else:
body = message.body
def _size(self, queue):
msg = loads(bytes_to_str(body))
msg['properties']['delivery_info']['azure_message'] = message
return msg
def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
delivery_info = self.qos.get(delivery_tag).delivery_info
if delivery_info['exchange'] in self._noack_queues:
return super().basic_ack(delivery_tag)
queue = self.entity_name(self.queue_name_prefix + delivery_info['exchange'])
queue_obj = self._get_asb_receiver(queue) # recv_mode is PEEK_LOCK when ack'ing messages
try:
queue_obj.receiver.complete_message(delivery_info['azure_message'])
except azure.servicebus.exceptions.MessageAlreadySettled:
super().basic_ack(delivery_tag)
except Exception:
super().basic_reject(delivery_tag)
else:
super().basic_ack(delivery_tag)
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue."""
return self._new_queue(queue).message_count
queue = self.entity_name(self.queue_name_prefix + queue)
props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
return props.total_message_count
def _purge(self, queue):
"""Delete all current messages in a queue."""
# Azure doesn't provide a purge api yet
n = 0
max_purge_count = 10
queue = self.entity_name(self.queue_name_prefix + queue)
# By default all the receivers will be in PEEK_LOCK receive mode
queue_obj = self._queue_cache.get(queue, None)
if queue not in self._noack_queues or queue_obj is None or queue_obj.receiver is None:
queue_obj = self._get_asb_receiver(queue, ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue)
while True:
message = self.queue_service.read_delete_queue_message(
self.entity_name(queue), timeout=0.1)
messages = queue_obj.receiver.receive_messages(max_message_count=max_purge_count,
max_wait_time=0.2)
n += len(messages)
if not message.body:
if len(messages) < max_purge_count:
break
else:
n += 1
return n
@property
def queue_service(self):
if self._queue_service is None:
self._queue_service = ServiceBusService(
service_namespace=self.conninfo.hostname,
shared_access_key_name=self.conninfo.userid,
shared_access_key_value=self.conninfo.password)
def close(self) -> None:
# receivers and senders spawn threads so clean them up
if not self.closed:
self.closed = True
for queue_obj in self._queue_cache.values():
queue_obj.close()
self._queue_cache.clear()
if self.connection is not None:
self.connection.close_channel(self)
@property
def queue_service(self) -> ServiceBusClient:
if self._queue_service is None:
self._queue_service = ServiceBusClient.from_connection_string(self._connection_string)
return self._queue_service
@property
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
if self._queue_mgmt_service is None:
self._queue_mgmt_service = ServiceBusAdministrationClient.from_connection_string(self._connection_string)
return self._queue_mgmt_service
@property
def conninfo(self):
return self.connection.client
@ -160,23 +310,19 @@ class Channel(virtual.Channel):
return self.connection.client.transport_options
@cached_property
def visibility_timeout(self):
return (self.transport_options.get('visibility_timeout') or
self.default_visibility_timeout)
@cached_property
def queue_name_prefix(self):
def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
@cached_property
def wait_time_seconds(self):
def wait_time_seconds(self) -> int:
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def peek_lock(self):
return self.transport_options.get('peek_lock',
self.default_peek_lock)
def peek_lock_seconds(self) -> int:
return min(self.transport_options.get('peek_lock_seconds',
self.default_peek_lock_seconds),
300) # Limit upper bounds to 300
class Transport(virtual.Transport):
@ -186,3 +332,4 @@ class Transport(virtual.Transport):
polling_interval = 1
default_port = None
can_parse_url = True

View File

@ -1 +1 @@
azure-servicebus>=0.21.1
azure-servicebus>=7.0.0

View File

@ -1,244 +1,277 @@
import json
import pytest
import base64
import random
from queue import Empty
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from collections import namedtuple
from kombu import messaging
from kombu import Connection, Exchange, Queue
from kombu.transport import azureservicebus
import azure.servicebus.exceptions
import azure.core.exceptions
pytest.importorskip('azure.servicebus')
try:
# azure-servicebus version >= 0.50.0
from azure.servicebus.control_client import Message, ServiceBusService
except ImportError:
try:
# azure-servicebus version <= 0.21.1
from azure.servicebus import Message, ServiceBusService
except ImportError:
ServiceBusService = Message = None
from azure.servicebus import ServiceBusMessage, ServiceBusReceiveMode
class QueueMock:
""" Hold information about a queue. """
class ASBQueue:
def __init__(self, kwargs):
self.options = kwargs
self.items = []
self.waiting_ack = []
self.send_calls = []
self.recv_calls = []
def __init__(self, name):
self.name = name
self.messages = []
self.message_count = len(self.messages)
def get_receiver(self, kwargs):
receive_mode = kwargs.get('receive_mode', ServiceBusReceiveMode.PEEK_LOCK)
def __repr__(self):
return 'QueueMock: {} messages'.format(len(self.messages))
class Receiver:
def close(self):
pass
def receive_messages(_self, **kwargs2):
max_message_count = kwargs2.get('max_message_count', 1)
result = []
if self.items:
while self.items or len(result) > max_message_count:
item = self.items.pop(0)
if receive_mode is ServiceBusReceiveMode.PEEK_LOCK:
self.waiting_ack.append(item)
result.append(item)
self.recv_calls.append({
'receiver_options': kwargs,
'receive_messages_options': kwargs2,
'messages': result
})
return result
return Receiver()
def get_sender(self):
class Sender:
def close(self):
pass
def send_messages(_self, msg):
self.send_calls.append(msg)
self.items.append(msg)
return Sender()
def _create_mock_connection(url='', **kwargs):
class _Channel(azureservicebus.Channel):
# reset _fanout_queues for each instance
queues = []
_queue_service = None
def list_queues(self):
return self.queues
@property
def queue_service(self):
if self._queue_service is None:
self._queue_service = AzureServiceBusClientMock()
return self._queue_service
class Transport(azureservicebus.Transport):
Channel = _Channel
return Connection(url, transport=Transport, **kwargs)
class AzureServiceBusClientMock:
class ASBMock:
def __init__(self):
"""
Imitate the ServiceBus Client.
"""
# queues doesn't exist on the real client, here for testing.
self.queues = []
self._queue_cache = {}
self.queues.append(self.create_queue(queue_name='unittest_queue'))
self.queues = {}
def create_queue(self, queue_name, queue=None, fail_on_exist=False):
queue = QueueMock(name=queue_name)
self.queues.append(queue)
self._queue_cache[queue_name] = queue
return queue
def get_queue_receiver(self, queue_name, **kwargs):
return self.queues[queue_name].get_receiver(kwargs)
def get_queue(self, queue_name=None):
for queue in self.queues:
if queue.name == queue_name:
return queue
def list_queues(self):
return self.queues
def send_queue_message(self, queue_name=None, message=None):
queue = self.get_queue(queue_name)
queue.messages.append(message)
def receive_queue_message(self, queue_name, peek_lock=True, timeout=60):
queue = self.get_queue(queue_name)
if queue and len(queue.messages):
return queue.messages.pop(0)
return Message()
def read_delete_queue_message(self, queue_name, timeout='60'):
return self.receive_queue_message(queue_name, timeout=timeout)
def delete_queue(self, queue_name=None):
queue = self.get_queue(queue_name)
if queue:
del queue
def get_queue_sender(self, queue_name):
return self.queues[queue_name].get_sender()
class test_Channel:
class ASBMgmtMock:
def __init__(self, queues):
self.queues = queues
def handleMessageCallback(self, message):
self.callback_message = message
def create_queue(self, queue_name, **kwargs):
if queue_name in self.queues:
raise azure.core.exceptions.ResourceExistsError()
self.queues[queue_name] = ASBQueue(kwargs)
def setup(self):
self.url = 'azureservicebus://'
self.queue_name = 'unittest_queue'
def delete_queue(self, queue_name):
self.queues.pop(queue_name, None)
self.exchange = Exchange('test_servicebus', type='direct')
self.queue = Queue(self.queue_name, self.exchange, self.queue_name)
self.connection = _create_mock_connection(self.url)
self.channel = self.connection.default_channel
self.queue(self.channel).declare()
def get_queue_runtime_properties(self, queue_name):
count = len(self.queues[queue_name].items)
mock = MagicMock()
mock.total_message_count = count
return mock
self.producer = messaging.Producer(self.channel,
self.exchange,
routing_key=self.queue_name)
self.channel.basic_consume(self.queue_name,
no_ack=False,
callback=self.handleMessageCallback,
consumer_tag='unittest')
URL_NOCREDS = 'azureservicebus://'
URL_CREDS = 'azureservicebus://policyname:ke/y@hostname'
def teardown(self):
# Removes QoS reserved messages so we don't restore msgs on shutdown.
try:
qos = self.channel._qos
except AttributeError:
pass
else:
if qos:
qos._dirty.clear()
qos._delivered.clear()
def test_queue_service(self):
# Test gettings queue service without credentials
conn = Connection(self.url, transport=azureservicebus.Transport)
with pytest.raises(ValueError) as exc:
conn.channel()
assert exc == 'You need to provide servicebus namespace'
def test_queue_service_nocredentials():
conn = Connection(URL_NOCREDS, transport=azureservicebus.Transport)
with pytest.raises(ValueError) as exc:
conn.channel()
assert exc == 'Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}'
# Test getting queue service when queue_service is not setted
with patch('kombu.transport.azureservicebus.ServiceBusService') as m:
channel = conn.channel()
# Remove queue service to get from service bus again
channel._queue_service = None
channel.queue_service
def test_queue_service():
# Test gettings queue service without credentials
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
with patch('kombu.transport.azureservicebus.ServiceBusClient') as m:
channel = conn.channel()
assert m.call_count == 2
# Check the SAS token "ke/y" has been parsed from the url correctly
assert channel._sas_key == 'ke/y'
# Calling queue_service again needs to reuse ServiceBus instance
channel.queue_service
assert m.call_count == 2
m.from_connection_string.return_value = 'test'
def test_conninfo(self):
conninfo = self.channel.conninfo
assert conninfo is self.connection
# Remove queue service to get from service bus again
channel._queue_service = None
assert channel.queue_service == 'test'
assert m.from_connection_string.call_count == 1
def test_transport_type(self):
transport_options = self.channel.transport_options
assert transport_options == {}
# Ensure that queue_service is cached
assert channel.queue_service == 'test'
assert m.from_connection_string.call_count == 1
def test_visibility_timeout(self):
# Test getting default visibility timeout
assert (
self.channel.visibility_timeout ==
azureservicebus.Channel.default_visibility_timeout
)
# Test getting value setted in transport options
del self.channel.visibility_timeout
self.channel.transport_options['visibility_timeout'] = 10
assert self.channel.visibility_timeout == 10
def test_conninfo():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
channel = conn.channel()
assert channel.conninfo is conn
def test_wait_timeout_seconds(self):
# Test getting default wait timeout seconds
assert (
self.channel.wait_time_seconds ==
azureservicebus.Channel.default_wait_time_seconds
)
# Test getting value setted in transport options
del self.channel.wait_time_seconds
self.channel.transport_options['wait_time_seconds'] = 10
assert self.channel.wait_time_seconds == 10
def test_transport_type():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
channel = conn.channel()
assert not channel.transport_options
def test_peek_lock(self):
# Test getting default peek lock
assert (
self.channel.peek_lock ==
azureservicebus.Channel.default_peek_lock
)
# Test getting value setted in transport options
del self.channel.peek_lock
self.channel.transport_options['peek_lock'] = True
assert self.channel.peek_lock is True
def test_default_wait_timeout_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
channel = conn.channel()
def test_get_from_azure(self):
# Test getting a single message
message = 'my test message'
self.producer.publish(message)
result = self.channel._get(self.queue_name)
assert 'body' in result.keys()
assert channel.wait_time_seconds == azureservicebus.Channel.default_wait_time_seconds
# Test getting multiple messages
for i in range(3):
message = f'message: {i}'
self.producer.publish(message)
queue_service = self.channel.queue_service
assert len(queue_service.get_queue(self.queue_name).messages) == 3
def test_custom_wait_timeout_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport, transport_options={'wait_time_seconds': 10})
channel = conn.channel()
for i in range(3):
result = self.channel._get(self.queue_name)
assert channel.wait_time_seconds == 10
assert len(queue_service.get_queue(self.queue_name).messages) == 0
def test_get_with_empty_list(self):
with pytest.raises(Empty):
self.channel._get(self.queue_name)
def test_default_peek_lock_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
channel = conn.channel()
def test_put_and_get(self):
message = 'my test message'
self.producer.publish(message)
results = self.queue(self.channel).get().payload
assert message == results
assert channel.peek_lock_seconds == azureservicebus.Channel.default_peek_lock_seconds
def test_delete_queue(self):
# Test deleting queue without message
queue_name = 'new_unittest_queue'
self.channel._new_queue(queue_name)
assert queue_name in self.channel._queue_cache
self.channel._delete(queue_name)
assert queue_name not in self.channel._queue_cache
def test_custom_peek_lock_seconds():
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
transport_options={'peek_lock_seconds': 65})
channel = conn.channel()
# Test deleting queue with message
message = 'my test message'
self.producer.publish(message)
self.channel._delete(self.queue_name)
assert queue_name not in self.channel._queue_cache
assert channel.peek_lock_seconds == 65
def test_invalid_peek_lock_seconds():
# Max is 300
conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
transport_options={'peek_lock_seconds': 900})
channel = conn.channel()
assert channel.peek_lock_seconds == 300
@pytest.fixture
def random_queue():
return 'azureservicebus_queue_{0}'.format(random.randint(1000,9999))
@pytest.fixture
def mock_asb():
return ASBMock()
@pytest.fixture
def mock_asb_management(mock_asb):
return ASBMgmtMock(queues=mock_asb.queues)
MockQueue = namedtuple('MockQueue', ['queue_name', 'asb', 'asb_mgmt', 'conn', 'channel', 'producer', 'queue'])
@pytest.fixture
def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue:
exchange = Exchange('test_servicebus', type='direct')
queue = Queue(random_queue, exchange, random_queue)
conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
channel = conn.channel()
channel._queue_service = mock_asb
channel._queue_mgmt_service = mock_asb_management
queue(channel).declare()
producer = messaging.Producer(channel, exchange, routing_key=random_queue)
return MockQueue(
random_queue,
mock_asb,
mock_asb_management,
conn,
channel,
producer,
queue
)
def test_basic_put_get(mock_queue: MockQueue):
text_message = "test message"
# This ends up hitting channel._put
mock_queue.producer.publish(text_message)
assert len(mock_queue.asb.queues[mock_queue.queue_name].items) == 1
azure_msg = mock_queue.asb.queues[mock_queue.queue_name].items[0]
assert isinstance(azure_msg, ServiceBusMessage)
message = mock_queue.channel._get(mock_queue.queue_name)
azure_msg_decoded = json.loads(str(azure_msg))
assert message['body'] == azure_msg_decoded['body']
# Check the message has been annotated with the azure message object
# which is used to ack later
assert message['properties']['delivery_info']['azure_message'] is azure_msg
assert base64.b64decode(message['body']).decode() == text_message
# Ack is on by default, check an ack is waiting
assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 1
def test_empty_queue_get(mock_queue: MockQueue):
with pytest.raises(Empty):
mock_queue.channel._get(mock_queue.queue_name)
def test_delete_empty_queue(mock_queue: MockQueue):
chan = mock_queue.channel
queue_name = 'random_queue_{0}'.format(random.randint(1000, 9999))
chan._new_queue(queue_name)
assert queue_name in chan._queue_cache
chan._delete(queue_name)
assert queue_name not in chan._queue_cache
def test_delete_populated_queue(mock_queue: MockQueue):
mock_queue.producer.publish('test1234')
mock_queue.channel._delete(mock_queue.queue_name)
assert mock_queue.queue_name not in mock_queue.channel._queue_cache
def test_purge(mock_queue: MockQueue):
mock_queue.producer.publish('test1234')
mock_queue.producer.publish('test1234')
mock_queue.producer.publish('test1234')
mock_queue.producer.publish('test1234')
size = mock_queue.channel._size(mock_queue.queue_name)
assert size == 4
assert mock_queue.channel._purge(mock_queue.queue_name) == 4
size = mock_queue.channel._size(mock_queue.queue_name)
assert size == 0
assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 0