mirror of https://github.com/celery/kombu.git
812 lines
29 KiB
Python
812 lines
29 KiB
Python
"""Testing module for the kombu.transport.SQS package.
|
|
|
|
NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from
|
|
http://github.com/pcsforeducation/sqs-mock-python. They have been patched
|
|
slightly.
|
|
"""
|
|
import base64
|
|
import os
|
|
from datetime import datetime, timedelta
|
|
|
|
import pytest
|
|
import random
|
|
import string
|
|
from queue import Empty
|
|
|
|
from unittest.mock import Mock, patch
|
|
|
|
from kombu import messaging
|
|
from kombu import Connection, Exchange, Queue
|
|
|
|
boto3 = pytest.importorskip('boto3')
|
|
|
|
from kombu.transport import SQS # noqa
|
|
from botocore.exceptions import ClientError # noqa
|
|
|
|
SQS_Channel_sqs = SQS.Channel.sqs
|
|
|
|
|
|
example_predefined_queues = {
|
|
'queue-1': {
|
|
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-1',
|
|
'access_key_id': 'a',
|
|
'secret_access_key': 'b',
|
|
'backoff_tasks': ['svc.tasks.tasks.task1'],
|
|
'backoff_policy': {1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640}
|
|
},
|
|
'queue-2': {
|
|
'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-2',
|
|
'access_key_id': 'c',
|
|
'secret_access_key': 'd',
|
|
},
|
|
}
|
|
|
|
|
|
class SQSMessageMock:
|
|
def __init__(self):
|
|
"""
|
|
Imitate the SQS Message from boto3.
|
|
"""
|
|
self.body = ""
|
|
self.receipt_handle = "receipt_handle_xyz"
|
|
|
|
|
|
class QueueMock:
|
|
""" Hold information about a queue. """
|
|
|
|
def __init__(self, url, creation_attributes=None):
|
|
self.url = url
|
|
# arguments of boto3.sqs.create_queue
|
|
self.creation_attributes = creation_attributes
|
|
self.attributes = {'ApproximateNumberOfMessages': '0'}
|
|
|
|
self.messages = []
|
|
|
|
def __repr__(self):
|
|
return 'QueueMock: {} {} messages'.format(self.url, len(self.messages))
|
|
|
|
|
|
class SQSClientMock:
|
|
|
|
def __init__(self, QueueName='unittest_queue'):
|
|
"""
|
|
Imitate the SQS Client from boto3.
|
|
"""
|
|
self._receive_messages_calls = 0
|
|
# _queues doesn't exist on the real client, here for testing.
|
|
self._queues = {}
|
|
url = self.create_queue(QueueName=QueueName)['QueueUrl']
|
|
self.send_message(QueueUrl=url, MessageBody='hello')
|
|
|
|
def _get_q(self, url):
|
|
""" Helper method to quickly get a queue. """
|
|
for q in self._queues.values():
|
|
if q.url == url:
|
|
return q
|
|
raise Exception(f"Queue url {url} not found")
|
|
|
|
def create_queue(self, QueueName=None, Attributes=None):
|
|
q = self._queues[QueueName] = QueueMock(
|
|
'https://sqs.us-east-1.amazonaws.com/xxx/' + QueueName,
|
|
Attributes,
|
|
)
|
|
return {'QueueUrl': q.url}
|
|
|
|
def list_queues(self, QueueNamePrefix=None):
|
|
""" Return a list of queue urls """
|
|
urls = (val.url for key, val in self._queues.items()
|
|
if key.startswith(QueueNamePrefix))
|
|
return {'QueueUrls': urls}
|
|
|
|
def get_queue_url(self, QueueName=None):
|
|
return self._queues[QueueName]
|
|
|
|
def send_message(self, QueueUrl=None, MessageBody=None):
|
|
for q in self._queues.values():
|
|
if q.url == QueueUrl:
|
|
handle = ''.join(random.choice(string.ascii_lowercase) for
|
|
x in range(10))
|
|
q.messages.append({'Body': MessageBody,
|
|
'ReceiptHandle': handle})
|
|
break
|
|
|
|
def receive_message(self, QueueUrl=None, MaxNumberOfMessages=1,
|
|
WaitTimeSeconds=10):
|
|
self._receive_messages_calls += 1
|
|
for q in self._queues.values():
|
|
if q.url == QueueUrl:
|
|
msgs = q.messages[:MaxNumberOfMessages]
|
|
q.messages = q.messages[MaxNumberOfMessages:]
|
|
return {'Messages': msgs} if msgs else {}
|
|
|
|
def get_queue_attributes(self, QueueUrl=None, AttributeNames=None):
|
|
if 'ApproximateNumberOfMessages' in AttributeNames:
|
|
count = len(self._get_q(QueueUrl).messages)
|
|
return {'Attributes': {'ApproximateNumberOfMessages': count}}
|
|
|
|
def purge_queue(self, QueueUrl=None):
|
|
for q in self._queues.values():
|
|
if q.url == QueueUrl:
|
|
q.messages = []
|
|
|
|
|
|
class test_Channel:
|
|
|
|
def handleMessageCallback(self, message):
|
|
self.callback_message = message
|
|
|
|
def setup(self):
|
|
"""Mock the back-end SQS classes"""
|
|
# Sanity check... if SQS is None, then it did not import and we
|
|
# cannot execute our tests.
|
|
SQS.Channel._queue_cache.clear()
|
|
|
|
# Common variables used in the unit tests
|
|
self.queue_name = 'unittest'
|
|
|
|
# Mock the sqs() method that returns an SQSConnection object and
|
|
# instead return an SQSConnectionMock() object.
|
|
sqs_conn_mock = SQSClientMock()
|
|
self.sqs_conn_mock = sqs_conn_mock
|
|
|
|
predefined_queues_sqs_conn_mocks = {
|
|
'queue-1': SQSClientMock(QueueName='queue-1'),
|
|
'queue-2': SQSClientMock(QueueName='queue-2'),
|
|
}
|
|
|
|
def mock_sqs():
|
|
def sqs(self, queue=None):
|
|
if queue in predefined_queues_sqs_conn_mocks:
|
|
return predefined_queues_sqs_conn_mocks[queue]
|
|
return sqs_conn_mock
|
|
|
|
return sqs
|
|
|
|
SQS.Channel.sqs = mock_sqs()
|
|
|
|
# Set up a task exchange for passing tasks through the queue
|
|
self.exchange = Exchange('test_SQS', type='direct')
|
|
self.queue = Queue(self.queue_name, self.exchange, self.queue_name)
|
|
|
|
# Mock up a test SQS Queue with the QueueMock class (and always
|
|
# make sure its a clean empty queue)
|
|
self.sqs_queue_mock = QueueMock('sqs://' + self.queue_name)
|
|
|
|
# Now, create our Connection object with the SQS Transport and store
|
|
# the connection/channel objects as references for use in these tests.
|
|
self.connection = Connection(transport=SQS.Transport)
|
|
self.channel = self.connection.channel()
|
|
|
|
self.queue(self.channel).declare()
|
|
self.producer = messaging.Producer(self.channel,
|
|
self.exchange,
|
|
routing_key=self.queue_name)
|
|
|
|
# Lastly, make sure that we're set up to 'consume' this queue.
|
|
self.channel.basic_consume(self.queue_name,
|
|
no_ack=False,
|
|
callback=self.handleMessageCallback,
|
|
consumer_tag='unittest')
|
|
|
|
def teardown(self):
|
|
# Removes QoS reserved messages so we don't restore msgs on shutdown.
|
|
try:
|
|
qos = self.channel._qos
|
|
except AttributeError:
|
|
pass
|
|
else:
|
|
if qos:
|
|
qos._dirty.clear()
|
|
qos._delivered.clear()
|
|
|
|
def test_init(self):
|
|
"""kombu.SQS.Channel instantiates correctly with mocked queues"""
|
|
assert self.queue_name in self.channel._queue_cache
|
|
|
|
def test_region(self):
|
|
_environ = dict(os.environ)
|
|
|
|
# when the region is unspecified
|
|
connection = Connection(transport=SQS.Transport)
|
|
channel = connection.channel()
|
|
assert channel.transport_options.get('region') is None
|
|
# the default region is us-east-1
|
|
assert channel.region == 'us-east-1'
|
|
|
|
# when boto3 picks a region
|
|
os.environ['AWS_DEFAULT_REGION'] = 'us-east-2'
|
|
assert boto3.Session().region_name == 'us-east-2'
|
|
# the default region should match
|
|
connection = Connection(transport=SQS.Transport)
|
|
channel = connection.channel()
|
|
assert channel.region == 'us-east-2'
|
|
|
|
# when transport_options are provided
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'region': 'us-west-2'
|
|
})
|
|
channel = connection.channel()
|
|
assert channel.transport_options.get('region') == 'us-west-2'
|
|
# the specified region should be used
|
|
assert connection.channel().region == 'us-west-2'
|
|
|
|
os.environ.clear()
|
|
os.environ.update(_environ)
|
|
|
|
def test_endpoint_url(self):
|
|
url = 'sqs://@localhost:5493'
|
|
self.connection = Connection(hostname=url, transport=SQS.Transport)
|
|
self.channel = self.connection.channel()
|
|
self.channel._sqs = None
|
|
expected_endpoint_url = 'http://localhost:5493'
|
|
assert self.channel.endpoint_url == expected_endpoint_url
|
|
boto3_sqs = SQS_Channel_sqs.__get__(self.channel, SQS.Channel)
|
|
assert boto3_sqs()._endpoint.host == expected_endpoint_url
|
|
|
|
def test_none_hostname_persists(self):
|
|
conn = Connection(hostname=None, transport=SQS.Transport)
|
|
assert conn.hostname == conn.clone().hostname
|
|
|
|
def test_entity_name(self):
|
|
assert self.channel.entity_name('foo') == 'foo'
|
|
assert self.channel.entity_name('foo.bar-baz*qux_quux') == \
|
|
'foo-bar-baz_qux_quux'
|
|
assert self.channel.entity_name('abcdef.fifo') == 'abcdef.fifo'
|
|
|
|
def test_new_queue(self):
|
|
queue_name = 'new_unittest_queue'
|
|
self.channel._new_queue(queue_name)
|
|
assert queue_name in self.sqs_conn_mock._queues.keys()
|
|
# For cleanup purposes, delete the queue and the queue file
|
|
self.channel._delete(queue_name)
|
|
|
|
def test_new_queue_custom_creation_attributes(self):
|
|
self.connection.transport_options['sqs-creation-attributes'] = {
|
|
'KmsMasterKeyId': 'alias/aws/sqs',
|
|
}
|
|
queue_name = 'new_custom_attribute_queue'
|
|
self.channel._new_queue(queue_name)
|
|
|
|
assert queue_name in self.sqs_conn_mock._queues.keys()
|
|
queue = self.sqs_conn_mock._queues[queue_name]
|
|
|
|
assert 'KmsMasterKeyId' in queue.creation_attributes
|
|
assert queue.creation_attributes['KmsMasterKeyId'] == 'alias/aws/sqs'
|
|
|
|
# For cleanup purposes, delete the queue and the queue file
|
|
self.channel._delete(queue_name)
|
|
|
|
def test_botocore_config_override(self):
|
|
expected_connect_timeout = 5
|
|
client_config = {'connect_timeout': expected_connect_timeout}
|
|
self.connection = Connection(
|
|
transport=SQS.Transport,
|
|
transport_options={'client-config': client_config},
|
|
)
|
|
self.channel = self.connection.channel()
|
|
self.channel._sqs = None
|
|
boto3_sqs = SQS_Channel_sqs.__get__(self.channel, SQS.Channel)
|
|
botocore_config = boto3_sqs()._client_config
|
|
assert botocore_config.connect_timeout == expected_connect_timeout
|
|
|
|
def test_dont_create_duplicate_new_queue(self):
|
|
# All queue names start with "q", except "unittest_queue".
|
|
# which is definitely out of cache when get_all_queues returns the
|
|
# first 1000 queues sorted by name.
|
|
queue_name = 'unittest_queue'
|
|
# This should not create a new queue.
|
|
self.channel._new_queue(queue_name)
|
|
assert queue_name in self.sqs_conn_mock._queues.keys()
|
|
queue = self.sqs_conn_mock._queues[queue_name]
|
|
# The queue originally had 1 message in it.
|
|
assert 1 == len(queue.messages)
|
|
assert 'hello' == queue.messages[0]['Body']
|
|
|
|
def test_delete(self):
|
|
queue_name = 'new_unittest_queue'
|
|
self.channel._new_queue(queue_name)
|
|
self.channel._delete(queue_name)
|
|
assert queue_name not in self.channel._queue_cache
|
|
|
|
def test_get_from_sqs(self):
|
|
# Test getting a single message
|
|
message = 'my test message'
|
|
self.producer.publish(message)
|
|
result = self.channel._get(self.queue_name)
|
|
assert 'body' in result.keys()
|
|
|
|
# Now test getting many messages
|
|
for i in range(3):
|
|
message = f'message: {i}'
|
|
self.producer.publish(message)
|
|
|
|
self.channel._get_bulk(self.queue_name, max_if_unlimited=3)
|
|
assert len(self.sqs_conn_mock._queues[self.queue_name].messages) == 0
|
|
|
|
def test_get_with_empty_list(self):
|
|
with pytest.raises(Empty):
|
|
self.channel._get(self.queue_name)
|
|
|
|
def test_get_bulk_raises_empty(self):
|
|
with pytest.raises(Empty):
|
|
self.channel._get_bulk(self.queue_name)
|
|
|
|
def test_is_base64_encoded(self):
|
|
raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \
|
|
b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa
|
|
b64_enc = base64.b64encode(raw)
|
|
assert self.channel._Channel__b64_encoded(b64_enc)
|
|
assert not self.channel._Channel__b64_encoded(raw)
|
|
assert not self.channel._Channel__b64_encoded(b"test123")
|
|
|
|
def test_messages_to_python(self):
|
|
from kombu.asynchronous.aws.sqs.message import Message
|
|
|
|
kombu_message_count = 3
|
|
json_message_count = 3
|
|
# Create several test messages and publish them
|
|
for i in range(kombu_message_count):
|
|
message = 'message: %s' % i
|
|
self.producer.publish(message)
|
|
|
|
# json formatted message NOT created by kombu
|
|
for i in range(json_message_count):
|
|
message = {'foo': 'bar'}
|
|
self.channel._put(self.producer.routing_key, message)
|
|
|
|
q_url = self.channel._new_queue(self.queue_name)
|
|
# Get the messages now
|
|
kombu_messages = []
|
|
for m in self.sqs_conn_mock.receive_message(
|
|
QueueUrl=q_url,
|
|
MaxNumberOfMessages=kombu_message_count)['Messages']:
|
|
m['Body'] = Message(body=m['Body']).decode()
|
|
kombu_messages.append(m)
|
|
json_messages = []
|
|
for m in self.sqs_conn_mock.receive_message(
|
|
QueueUrl=q_url,
|
|
MaxNumberOfMessages=json_message_count)['Messages']:
|
|
m['Body'] = Message(body=m['Body']).decode()
|
|
json_messages.append(m)
|
|
|
|
# Now convert them to payloads
|
|
kombu_payloads = self.channel._messages_to_python(
|
|
kombu_messages, self.queue_name,
|
|
)
|
|
json_payloads = self.channel._messages_to_python(
|
|
json_messages, self.queue_name,
|
|
)
|
|
|
|
# We got the same number of payloads back, right?
|
|
assert len(kombu_payloads) == kombu_message_count
|
|
assert len(json_payloads) == json_message_count
|
|
|
|
# Make sure they're payload-style objects
|
|
for p in kombu_payloads:
|
|
assert 'properties' in p
|
|
for p in json_payloads:
|
|
assert 'properties' in p
|
|
|
|
def test_put_and_get(self):
|
|
message = 'my test message'
|
|
self.producer.publish(message)
|
|
results = self.queue(self.channel).get().payload
|
|
assert message == results
|
|
|
|
def test_redelivered(self):
|
|
self.channel.sqs().change_message_visibility = \
|
|
Mock(name='change_message_visibility')
|
|
message = {
|
|
'redelivered': True,
|
|
'properties': {'delivery_tag': 'test_message_id'}
|
|
}
|
|
self.channel._put(self.producer.routing_key, message)
|
|
self.sqs_conn_mock.change_message_visibility.assert_called_once()
|
|
|
|
def test_put_and_get_bulk(self):
|
|
# With QoS.prefetch_count = 0
|
|
message = 'my test message'
|
|
self.producer.publish(message)
|
|
self.channel.connection._deliver = Mock(name='_deliver')
|
|
self.channel._get_bulk(self.queue_name)
|
|
self.channel.connection._deliver.assert_called_once()
|
|
|
|
def test_puts_and_get_bulk(self):
|
|
# Generate 8 messages
|
|
message_count = 8
|
|
|
|
# Set the prefetch_count to 5
|
|
self.channel.qos.prefetch_count = 5
|
|
|
|
# Now, generate all the messages
|
|
for i in range(message_count):
|
|
message = 'message: %s' % i
|
|
self.producer.publish(message)
|
|
|
|
# Count how many messages are retrieved the first time. Should
|
|
# be 5 (message_count).
|
|
self.channel.connection._deliver = Mock(name='_deliver')
|
|
self.channel._get_bulk(self.queue_name)
|
|
assert self.channel.connection._deliver.call_count == 5
|
|
for i in range(5):
|
|
self.channel.qos.append(Mock(name=f'message{i}'), i)
|
|
|
|
# Now, do the get again, the number of messages returned should be 1.
|
|
self.channel.connection._deliver.reset_mock()
|
|
self.channel._get_bulk(self.queue_name)
|
|
self.channel.connection._deliver.assert_called_once()
|
|
|
|
def test_drain_events_with_empty_list(self):
|
|
def mock_can_consume():
|
|
return False
|
|
self.channel.qos.can_consume = mock_can_consume
|
|
with pytest.raises(Empty):
|
|
self.channel.drain_events()
|
|
|
|
def test_drain_events_with_prefetch_5(self):
|
|
# Generate 20 messages
|
|
message_count = 20
|
|
prefetch_count = 5
|
|
|
|
current_delivery_tag = [1]
|
|
|
|
# Set the prefetch_count to 5
|
|
self.channel.qos.prefetch_count = prefetch_count
|
|
self.channel.connection._deliver = Mock(name='_deliver')
|
|
|
|
def on_message_delivered(message, queue):
|
|
current_delivery_tag[0] += 1
|
|
self.channel.qos.append(message, current_delivery_tag[0])
|
|
self.channel.connection._deliver.side_effect = on_message_delivered
|
|
|
|
# Now, generate all the messages
|
|
for i in range(message_count):
|
|
self.producer.publish('message: %s' % i)
|
|
|
|
# Now drain all the events
|
|
for i in range(1000):
|
|
try:
|
|
self.channel.drain_events(timeout=0)
|
|
except Empty:
|
|
break
|
|
else:
|
|
assert False, 'disabled infinite loop'
|
|
|
|
self.channel.qos._flush()
|
|
assert len(self.channel.qos._delivered) == prefetch_count
|
|
|
|
assert self.channel.connection._deliver.call_count == prefetch_count
|
|
|
|
def test_drain_events_with_prefetch_none(self):
|
|
# Generate 20 messages
|
|
message_count = 20
|
|
expected_receive_messages_count = 3
|
|
|
|
current_delivery_tag = [1]
|
|
|
|
# Set the prefetch_count to None
|
|
self.channel.qos.prefetch_count = None
|
|
self.channel.connection._deliver = Mock(name='_deliver')
|
|
|
|
def on_message_delivered(message, queue):
|
|
current_delivery_tag[0] += 1
|
|
self.channel.qos.append(message, current_delivery_tag[0])
|
|
self.channel.connection._deliver.side_effect = on_message_delivered
|
|
|
|
# Now, generate all the messages
|
|
for i in range(message_count):
|
|
self.producer.publish('message: %s' % i)
|
|
|
|
# Now drain all the events
|
|
for i in range(1000):
|
|
try:
|
|
self.channel.drain_events(timeout=0)
|
|
except Empty:
|
|
break
|
|
else:
|
|
assert False, 'disabled infinite loop'
|
|
|
|
assert self.channel.connection._deliver.call_count == message_count
|
|
|
|
# How many times was the SQSConnectionMock receive_message method
|
|
# called?
|
|
assert (expected_receive_messages_count ==
|
|
self.sqs_conn_mock._receive_messages_calls)
|
|
|
|
def test_basic_ack(self, ):
|
|
"""Test that basic_ack calls the delete_message properly"""
|
|
message = {
|
|
'sqs_message': {
|
|
'ReceiptHandle': '1'
|
|
},
|
|
'sqs_queue': 'testing_queue'
|
|
}
|
|
mock_messages = Mock()
|
|
mock_messages.delivery_info = message
|
|
self.channel.qos.append(mock_messages, 1)
|
|
self.channel.sqs().delete_message = Mock()
|
|
self.channel.basic_ack(1)
|
|
self.sqs_conn_mock.delete_message.assert_called_with(
|
|
QueueUrl=message['sqs_queue'],
|
|
ReceiptHandle=message['sqs_message']['ReceiptHandle']
|
|
)
|
|
assert {1} == self.channel.qos._dirty
|
|
|
|
@patch('kombu.transport.virtual.base.Channel.basic_ack')
|
|
@patch('kombu.transport.virtual.base.Channel.basic_reject')
|
|
def test_basic_ack_with_mocked_channel_methods(self, basic_reject_mock,
|
|
basic_ack_mock):
|
|
"""Test that basic_ack calls the delete_message properly"""
|
|
message = {
|
|
'sqs_message': {
|
|
'ReceiptHandle': '1'
|
|
},
|
|
'sqs_queue': 'testing_queue'
|
|
}
|
|
mock_messages = Mock()
|
|
mock_messages.delivery_info = message
|
|
self.channel.qos.append(mock_messages, 1)
|
|
self.channel.sqs().delete_message = Mock()
|
|
self.channel.basic_ack(1)
|
|
self.sqs_conn_mock.delete_message.assert_called_with(
|
|
QueueUrl=message['sqs_queue'],
|
|
ReceiptHandle=message['sqs_message']['ReceiptHandle']
|
|
)
|
|
basic_ack_mock.assert_called_with(1)
|
|
assert not basic_reject_mock.called
|
|
|
|
@patch('kombu.transport.virtual.base.Channel.basic_ack')
|
|
@patch('kombu.transport.virtual.base.Channel.basic_reject')
|
|
def test_basic_ack_without_sqs_message(self, basic_reject_mock,
|
|
basic_ack_mock):
|
|
"""Test that basic_ack calls the delete_message properly"""
|
|
message = {
|
|
'sqs_queue': 'testing_queue'
|
|
}
|
|
mock_messages = Mock()
|
|
mock_messages.delivery_info = message
|
|
self.channel.qos.append(mock_messages, 1)
|
|
self.channel.sqs().delete_message = Mock()
|
|
self.channel.basic_ack(1)
|
|
assert not self.sqs_conn_mock.delete_message.called
|
|
basic_ack_mock.assert_called_with(1)
|
|
assert not basic_reject_mock.called
|
|
|
|
@patch('kombu.transport.virtual.base.Channel.basic_ack')
|
|
@patch('kombu.transport.virtual.base.Channel.basic_reject')
|
|
def test_basic_ack_invalid_receipt_handle(self, basic_reject_mock,
|
|
basic_ack_mock):
|
|
"""Test that basic_ack calls the delete_message properly"""
|
|
message = {
|
|
'sqs_message': {
|
|
'ReceiptHandle': '2'
|
|
},
|
|
'sqs_queue': 'testing_queue'
|
|
}
|
|
error_response = {
|
|
'Error': {
|
|
'Code': 'InvalidParameterValue',
|
|
'Message': 'Value 2 for parameter ReceiptHandle is invalid.'
|
|
' Reason: The receipt handle has expired.'
|
|
}
|
|
}
|
|
operation_name = 'DeleteMessage'
|
|
|
|
mock_messages = Mock()
|
|
mock_messages.delivery_info = message
|
|
self.channel.qos.append(mock_messages, 2)
|
|
self.channel.sqs().delete_message = Mock()
|
|
self.channel.sqs().delete_message.side_effect = ClientError(
|
|
error_response=error_response,
|
|
operation_name=operation_name
|
|
)
|
|
self.channel.basic_ack(2)
|
|
self.sqs_conn_mock.delete_message.assert_called_with(
|
|
QueueUrl=message['sqs_queue'],
|
|
ReceiptHandle=message['sqs_message']['ReceiptHandle']
|
|
)
|
|
basic_reject_mock.assert_called_with(2)
|
|
assert not basic_ack_mock.called
|
|
|
|
def test_predefined_queues_primes_queue_cache(self):
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
})
|
|
channel = connection.channel()
|
|
|
|
assert 'queue-1' in channel._queue_cache
|
|
assert 'queue-2' in channel._queue_cache
|
|
|
|
def test_predefined_queues_new_queue_raises_if_queue_not_exists(self):
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
})
|
|
channel = connection.channel()
|
|
|
|
with pytest.raises(SQS.UndefinedQueueException):
|
|
channel._new_queue('queue-99')
|
|
|
|
def test_predefined_queues_get_from_sqs(self):
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
})
|
|
channel = connection.channel()
|
|
|
|
def message_to_python(message, queue_name, queue):
|
|
return message
|
|
|
|
channel._message_to_python = Mock(side_effect=message_to_python)
|
|
|
|
queue_name = "queue-1"
|
|
|
|
exchange = Exchange('test_SQS', type='direct')
|
|
p = messaging.Producer(channel, exchange, routing_key=queue_name)
|
|
queue = Queue(queue_name, exchange, queue_name)
|
|
queue(channel).declare()
|
|
|
|
# Getting a single message
|
|
p.publish('message')
|
|
result = channel._get(queue_name)
|
|
|
|
assert 'Body' in result.keys()
|
|
|
|
# Getting many messages
|
|
for i in range(3):
|
|
p.publish(f'message: {i}')
|
|
|
|
channel.connection._deliver = Mock(name='_deliver')
|
|
channel._get_bulk(queue_name, max_if_unlimited=3)
|
|
channel.connection._deliver.assert_called()
|
|
|
|
assert len(channel.sqs(queue_name)._queues[queue_name].messages) == 0
|
|
|
|
def test_predefined_queues_backoff_policy(self):
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
})
|
|
channel = connection.channel()
|
|
|
|
def apply_backoff_policy(
|
|
queue_name, delivery_tag, retry_policy, backoff_tasks):
|
|
return None
|
|
|
|
mock_apply_policy = Mock(side_effect=apply_backoff_policy)
|
|
channel.qos.apply_backoff_policy = mock_apply_policy
|
|
queue_name = "queue-1"
|
|
|
|
exchange = Exchange('test_SQS', type='direct')
|
|
queue = Queue(queue_name, exchange, queue_name)
|
|
queue(channel).declare()
|
|
|
|
message_mock = Mock()
|
|
message_mock.delivery_info = {'routing_key': queue_name}
|
|
channel.qos._delivered['test_message_id'] = message_mock
|
|
channel.qos.reject('test_message_id')
|
|
mock_apply_policy.assert_called_once_with(
|
|
'queue-1', 'test_message_id',
|
|
{1: 10, 2: 20, 3: 40, 4: 80, 5: 320, 6: 640},
|
|
['svc.tasks.tasks.task1']
|
|
)
|
|
|
|
def test_predefined_queues_change_visibility_timeout(self):
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
})
|
|
channel = connection.channel()
|
|
|
|
def extract_task_name_and_number_of_retries(delivery_tag):
|
|
return 'svc.tasks.tasks.task1', 2
|
|
|
|
mock_extract_task_name_and_number_of_retries = Mock(
|
|
side_effect=extract_task_name_and_number_of_retries)
|
|
channel.qos.extract_task_name_and_number_of_retries = \
|
|
mock_extract_task_name_and_number_of_retries
|
|
|
|
queue_name = "queue-1"
|
|
|
|
exchange = Exchange('test_SQS', type='direct')
|
|
queue = Queue(queue_name, exchange, queue_name)
|
|
queue(channel).declare()
|
|
|
|
message_mock = Mock()
|
|
message_mock.delivery_info = {'routing_key': queue_name}
|
|
channel.qos._delivered['test_message_id'] = message_mock
|
|
|
|
channel.sqs = Mock()
|
|
sqs_queue_mock = Mock()
|
|
channel.sqs.return_value = sqs_queue_mock
|
|
channel.qos.reject('test_message_id')
|
|
|
|
sqs_queue_mock.change_message_visibility.assert_called_once_with(
|
|
QueueUrl='https://sqs.us-east-1.amazonaws.com/xxx/queue-1',
|
|
ReceiptHandle='test_message_id', VisibilityTimeout=20)
|
|
|
|
def test_sts_new_session(self):
|
|
# Arrange
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
'sts_role_arn': 'test::arn'
|
|
})
|
|
channel = connection.channel()
|
|
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
|
|
queue_name = 'queue-1'
|
|
|
|
mock_generate_sts_session_token = Mock()
|
|
mock_new_sqs_client = Mock()
|
|
channel.new_sqs_client = mock_new_sqs_client
|
|
mock_generate_sts_session_token.side_effect = [
|
|
{
|
|
'Expiration': 123,
|
|
'SessionToken': 123,
|
|
'AccessKeyId': 123,
|
|
'SecretAccessKey': 123
|
|
}
|
|
]
|
|
channel.generate_sts_session_token = mock_generate_sts_session_token
|
|
|
|
# Act
|
|
sqs(queue=queue_name)
|
|
|
|
# Assert
|
|
mock_generate_sts_session_token.assert_called_once()
|
|
|
|
def test_sts_session_expired(self):
|
|
# Arrange
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
'sts_role_arn': 'test::arn'
|
|
})
|
|
channel = connection.channel()
|
|
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
|
|
channel.sts_expiration = datetime.utcnow() - timedelta(days=1)
|
|
queue_name = 'queue-1'
|
|
|
|
mock_generate_sts_session_token = Mock()
|
|
mock_new_sqs_client = Mock()
|
|
channel.new_sqs_client = mock_new_sqs_client
|
|
mock_generate_sts_session_token.side_effect = [
|
|
{
|
|
'Expiration': 123,
|
|
'SessionToken': 123,
|
|
'AccessKeyId': 123,
|
|
'SecretAccessKey': 123
|
|
}
|
|
]
|
|
channel.generate_sts_session_token = mock_generate_sts_session_token
|
|
|
|
# Act
|
|
sqs(queue=queue_name)
|
|
|
|
# Assert
|
|
mock_generate_sts_session_token.assert_called_once()
|
|
|
|
def test_sts_session_not_expired(self):
|
|
# Arrange
|
|
connection = Connection(transport=SQS.Transport, transport_options={
|
|
'predefined_queues': example_predefined_queues,
|
|
'sts_role_arn': 'test::arn'
|
|
})
|
|
channel = connection.channel()
|
|
channel.sts_expiration = datetime.utcnow() + timedelta(days=1)
|
|
queue_name = 'queue-1'
|
|
|
|
mock_generate_sts_session_token = Mock()
|
|
mock_new_sqs_client = Mock()
|
|
channel.new_sqs_client = mock_new_sqs_client
|
|
channel._predefined_queue_clients = {queue_name: 'mock_client'}
|
|
mock_generate_sts_session_token.side_effect = [
|
|
{
|
|
'Expiration': 123,
|
|
'SessionToken': 123,
|
|
'AccessKeyId': 123,
|
|
'SecretAccessKey': 123
|
|
}
|
|
]
|
|
channel.generate_sts_session_token = mock_generate_sts_session_token
|
|
|
|
# Act
|
|
channel.sqs(queue=queue_name)
|
|
|
|
# Assert
|
|
mock_generate_sts_session_token.assert_not_called()
|