# Mirror of https://github.com/celery/kombu.git
# 519 lines, 17 KiB, Python
from __future__ import annotations
|
|
|
|
import socket
|
|
from typing import TYPE_CHECKING
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from amqp import RecoverableConnectionError
|
|
|
|
from kombu import common
|
|
from kombu.common import (PREFETCH_COUNT_MAX, Broadcast, QoS, collect_replies,
|
|
declaration_cached, generate_oid, ignore_errors,
|
|
maybe_declare, send_reply)
|
|
from t.mocks import ContextMock, MockPool
|
|
|
|
if TYPE_CHECKING:
|
|
from types import TracebackType
|
|
|
|
|
|
def test_generate_oid():
    """generate_oid falls back to uuid5 when uuid3 raises ValueError.

    ``uuid3`` is patched to raise, so only the ``uuid5`` result can reach
    the caller. Both functions must be called with the OID namespace and
    the hex-formatted ``node-process-thread-id(instance)`` string.
    """
    from uuid import NAMESPACE_OID

    instance = Mock()

    args = (1, 1001, 2001, id(instance))
    ent = '%x-%x-%x-%x' % args

    with patch('kombu.common.uuid3') as mock_uuid3, \
            patch('kombu.common.uuid5') as mock_uuid5:
        # A side_effect takes precedence over return_value on a Mock, so
        # configuring a return_value for uuid3 here would be dead code.
        mock_uuid3.side_effect = ValueError
        mock_uuid5.return_value = 'uuid5-6ba7b812-9dad-11d1-80b4'

        oid = generate_oid(1, 1001, 2001, instance)

        # uuid3 must be attempted first, with the same arguments...
        mock_uuid3.assert_called_once_with(NAMESPACE_OID, ent)
        # ...and uuid5 is used only as the fallback.
        mock_uuid5.assert_called_once_with(NAMESPACE_OID, ent)
        assert oid == 'uuid5-6ba7b812-9dad-11d1-80b4'
|
|
|
|
|
|
def test_ignore_errors():
    """ignore_errors swallows only the errors the connection declares."""
    conn = Mock()
    conn.channel_errors = (KeyError,)
    conn.connection_errors = (KeyError,)

    # Context-manager form: a declared error does not propagate.
    with ignore_errors(conn):
        raise KeyError()

    # Callable form: an error raised by the wrapped function is swallowed.
    def fail():
        raise KeyError()

    ignore_errors(conn, fail)

    # With no declared errors, the exception escapes the context manager.
    conn.channel_errors = ()
    conn.connection_errors = ()
    with pytest.raises(KeyError):
        with ignore_errors(conn):
            raise KeyError()
|
|
|
|
|
|
class test_declaration_cached:
    """declaration_cached consults the client's declared_entities list."""

    @staticmethod
    def _channel_declaring(*entities):
        # Mock channel whose client reports ``entities`` as already declared.
        chan = Mock()
        chan.connection.client.declared_entities = list(entities)
        return chan

    def test_when_cached(self):
        assert declaration_cached('foo', self._channel_declaring('foo'))

    def test_when_not_cached(self):
        assert not declaration_cached('foo', self._channel_declaring('bar'))
|
|
|
|
|
|
class test_Broadcast:

    def test_arguments(self):
        # With no explicit queue name, the queue is named "bcast.<uuid>".
        with patch('kombu.common.uuid', return_value='test') as fake_uuid:
            queue = Broadcast(name='test_Broadcast')
            fake_uuid.assert_called_with()
            assert queue.name == 'bcast.test'
            assert queue.alias == 'test_Broadcast'
            assert queue.auto_delete
            assert queue.exchange.name == 'test_Broadcast'
            assert queue.exchange.type == 'fanout'

        # An explicit queue name is used verbatim...
        queue = Broadcast('test_Broadcast', 'explicit_queue_name')
        assert queue.name == 'explicit_queue_name'
        assert queue.exchange.name == 'test_Broadcast'

        # ...and is kept when the queue is bound to a channel.
        bound = queue(Mock())
        assert bound.name == queue.name

        # unique=True appends a uuid suffix to the explicit name.
        with patch('kombu.common.uuid', return_value='test') as fake_uuid:
            queue = Broadcast('test_Broadcast',
                              'explicit_queue_name',
                              unique=True)
            fake_uuid.assert_called_with()
            assert queue.name == 'explicit_queue_name.test'

            bound = queue(Mock())
            # Only the prefix before the suffix is compared.
            assert bound.name.split('.')[0] == queue.name.split('.')[0]
|
|
|
|
|
|
class test_maybe_declare:

    @staticmethod
    def _mock_channel():
        """Mock channel whose client starts with an empty declared cache."""
        chan = Mock()
        chan.connection.client.declared_entities = set()
        return chan

    @staticmethod
    def _mock_entity(is_bound=False, can_cache_declaration=True):
        """Mock entity; ``bind(channel)`` attaches it and marks it bound."""
        ent = Mock()
        ent.can_cache_declaration = can_cache_declaration
        ent.is_bound = is_bound

        def bind(chan):
            ent.channel = chan
            ent.is_bound = True
            return ent

        ent.bind = bind
        return ent

    def _bound_entity(self, channel):
        """Cacheable mock entity that is already bound to ``channel``."""
        ent = self._mock_entity(is_bound=True, can_cache_declaration=True)
        ent.channel = channel
        return ent

    def test_cacheable(self):
        channel = self._mock_channel()
        entity = self._bound_entity(channel)
        entity.auto_delete = False
        assert entity.is_bound, "Expected entity is bound to begin this test."

        # The first call declares the entity and records it in the cache.
        maybe_declare(entity, channel)
        assert entity.declare.call_count == 1
        assert hash(entity) in channel.connection.client.declared_entities

        # The second call is answered from the cache, so no new declare.
        maybe_declare(entity, channel)
        assert entity.declare.call_count == 1

        # Once the channel's connection is gone, the call must raise
        # RecoverableConnectionError.
        entity.channel.connection = None
        with pytest.raises(RecoverableConnectionError):
            maybe_declare(entity)

    def test_binds_entities(self):
        channel = self._mock_channel()
        entity = self._mock_entity()
        assert not entity.is_bound, "Expected entity unbound to begin test."

        # Without a retry policy, maybe_declare binds the entity itself.
        maybe_declare(entity, channel)

        assert entity.is_bound is True, "Expected entity is now marked bound."

    def test_binds_entities_when_retry_policy(self):
        channel = self._mock_channel()
        entity = self._mock_entity()
        assert not entity.is_bound, "Expected entity unbound to begin test."

        retry_policy = {
            'interval_start': 0,
            'interval_max': 1,
            'max_retries': 3,
            'interval_step': 0.2,
            'errback': lambda x: "Called test errback retry policy",
        }

        # Binding must also happen on the retry code path.
        maybe_declare(entity, channel, retry=True, **retry_policy)

        assert entity.is_bound is True, "Expected entity is now marked bound."

    def test_with_retry(self):
        channel = self._mock_channel()
        entity = self._bound_entity(channel)
        assert entity.is_bound, "Expected entity is bound to begin this test."

        # retry=True with the default policy goes through client.ensure().
        maybe_declare(entity, channel, retry=True)

        assert channel.connection.client.ensure.call_count

    def test_with_retry_dropped_connection(self):
        channel = self._mock_channel()
        entity = self._bound_entity(channel)
        assert entity.is_bound, "Expected entity is bound to begin this test."

        # The connection is lost before the call; retrying must not mask
        # the RecoverableConnectionError.
        entity.channel.connection = None
        with pytest.raises(RecoverableConnectionError):
            maybe_declare(entity, channel, retry=True)
|
|
|
|
|
|
class test_replies:

    def test_send_reply(self):
        req = Mock()
        req.content_type = 'application/json'
        req.content_encoding = 'binary'
        req.properties = {'reply_to': 'hello', 'correlation_id': 'world'}

        channel = Mock()
        exchange = Mock()
        exchange.is_bound = True
        exchange.channel = channel

        producer = Mock()
        producer.channel = channel
        producer.channel.connection.client.declared_entities = set()

        send_reply(exchange, req, {'hello': 'world'}, producer)

        # The publish routing and serialization mirror the request's
        # reply_to, correlation_id, content type and encoding.
        assert producer.publish.call_count
        pub_args, pub_kwargs = producer.publish.call_args
        assert pub_args[0] == {'hello': 'world'}
        assert pub_kwargs == {
            'exchange': exchange,
            'routing_key': 'hello',
            'correlation_id': 'world',
            'serializer': 'json',
            'retry': False,
            'retry_policy': None,
            'content_encoding': 'binary',
        }

    @patch('kombu.common.itermessages')
    def test_collect_replies_with_ack(self, fake_itermessages):
        conn, channel, queue = Mock(), Mock(), Mock()
        body, message = Mock(), Mock()
        fake_itermessages.return_value = [(body, message)]

        replies = collect_replies(conn, channel, queue, no_ack=False)
        assert next(replies) is body
        fake_itermessages.assert_called_with(conn, channel, queue,
                                             no_ack=False)
        # With no_ack=False every reply is acknowledged explicitly.
        message.ack.assert_called_with()

        with pytest.raises(StopIteration):
            next(replies)

        channel.after_reply_message_received.assert_called_with(queue.name)

    @patch('kombu.common.itermessages')
    def test_collect_replies_no_ack(self, fake_itermessages):
        conn, channel, queue = Mock(), Mock(), Mock()
        body, message = Mock(), Mock()
        fake_itermessages.return_value = [(body, message)]

        replies = collect_replies(conn, channel, queue)
        assert next(replies) is body
        # no_ack defaults to True, so nothing is acknowledged.
        fake_itermessages.assert_called_with(conn, channel, queue,
                                             no_ack=True)
        message.ack.assert_not_called()

    @patch('kombu.common.itermessages')
    def test_collect_replies_no_replies(self, fake_itermessages):
        conn, channel, queue = Mock(), Mock(), Mock()
        fake_itermessages.return_value = []

        replies = collect_replies(conn, channel, queue)
        with pytest.raises(StopIteration):
            next(replies)
        # No replies received, so the post-reply hook must not fire.
        channel.after_reply_message_received.assert_not_called()
|
|
|
|
|
|
class test_insured:

    @patch('kombu.common.logger')
    def test_ensure_errback(self, fake_logger):
        common._ensure_errback('foo', 30)
        fake_logger.error.assert_called()

    def test_revive_connection(self):
        on_revive = Mock()
        channel = Mock()
        common.revive_connection(Mock(), channel, on_revive)
        on_revive.assert_called_with(channel)

        # A None callback must be accepted without error.
        common.revive_connection(Mock(), channel, None)

    def get_insured_mocks(self, insured_returns=('works', 'ignored')):
        """Return ``(conn, pool, fun, insured)`` mocks for insured() tests.

        ``conn.autoretry(...)`` returns ``insured``, which in turn returns
        ``insured_returns`` when called.
        """
        conn = ContextMock()
        pool = MockPool(conn)
        fun = Mock()
        insured = Mock(return_value=insured_returns)
        conn.autoretry.return_value = insured
        return conn, pool, fun, insured

    def test_insured(self):
        conn, pool, fun, insured = self.get_insured_mocks()

        result = common.insured(pool, fun, (2, 2), {'foo': 'bar'})
        assert result == 'works'
        conn.ensure_connection.assert_called_with(
            errback=common._ensure_errback,
        )

        # The autoretry wrapper receives the args plus the connection.
        insured.assert_called()
        call_args, call_kwargs = insured.call_args
        assert call_args == (2, 2)
        assert call_kwargs == {'foo': 'bar', 'connection': conn}

        # autoretry wraps ``fun`` on the default channel with hooks set.
        conn.autoretry.assert_called()
        retry_args, retry_kwargs = conn.autoretry.call_args
        assert retry_args == (fun, conn.default_channel)
        assert retry_kwargs.get('on_revive')
        assert retry_kwargs.get('errback')

    def test_insured_custom_errback(self):
        conn, pool, fun, _ = self.get_insured_mocks()

        my_errback = Mock()
        common.insured(pool, fun, (2, 2), {'foo': 'bar'},
                       errback=my_errback)
        conn.ensure_connection.assert_called_with(errback=my_errback)
|
|
|
|
|
|
class MockConsumer:
|
|
consumers = set()
|
|
|
|
def __init__(self, channel, queues=None, callbacks=None, **kwargs):
|
|
self.channel = channel
|
|
self.queues = queues
|
|
self.callbacks = callbacks
|
|
|
|
def __enter__(self):
|
|
self.consumers.add(self)
|
|
return self
|
|
|
|
def __exit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None
|
|
) -> None:
|
|
self.consumers.discard(self)
|
|
|
|
|
|
class test_itermessages:

    class MockConnection:
        """Fake connection whose drain_events feeds the active consumers."""

        # When True, drain_events raises socket.timeout instead.
        should_raise_timeout = False

        def drain_events(self, **kwargs):
            if self.should_raise_timeout:
                raise socket.timeout()
            # Deliver one ('body', 'message') pair to every callback of every
            # consumer currently registered by MockConsumer.__enter__.
            for active in MockConsumer.consumers:
                for cb in active.callbacks:
                    cb('body', 'message')

    def _make_connection(self, raise_timeout=False):
        """Build a MockConnection whose Consumer factory is MockConsumer."""
        conn = self.MockConnection()
        conn.should_raise_timeout = raise_timeout
        conn.Consumer = MockConsumer
        return conn

    def test_default(self):
        conn = self._make_connection()
        channel = Mock()
        channel.connection.client = conn

        messages = common.itermessages(conn, channel, 'q', limit=1)
        assert next(messages) == ('body', 'message')

        # limit=1: exactly one message, then the iterator is exhausted.
        with pytest.raises(StopIteration):
            next(messages)

    def test_when_raises_socket_timeout(self):
        conn = self._make_connection(raise_timeout=True)
        channel = Mock()
        channel.connection.client = conn

        # The socket timeout is not propagated; iteration simply ends.
        messages = common.itermessages(conn, channel, 'q', limit=1)
        with pytest.raises(StopIteration):
            next(messages)

    @patch('kombu.common.deque')
    def test_when_raises_IndexError(self, fake_deque):
        acc = Mock()
        acc.popleft.side_effect = IndexError()
        fake_deque.return_value = acc

        conn = self._make_connection()
        channel = Mock()

        # An empty buffer (popleft raising IndexError) must not escape;
        # the iterator is just exhausted.
        messages = common.itermessages(conn, channel, 'q', limit=1)
        with pytest.raises(StopIteration):
            next(messages)
|
|
|
|
|
|
class test_QoS:

    class _QoS(QoS):
        """QoS whose ``set`` just echoes the value and invokes no callback."""

        def __init__(self, value):
            self.value = value
            super().__init__(None, value)

        def set(self, value):
            return value

    def test_qos_exceeds_16bit(self):
        with patch('kombu.common.logger') as logger:
            callback = Mock()
            qos = QoS(callback, 10)
            qos.prev = 100
            # An oversized prefetch count is logged as a warning and passed
            # to the callback as 0.
            qos.set(2 ** 32)
            logger.warning.assert_called()
            callback.assert_called_with(prefetch_count=0)

    def test_qos_increment_decrement(self):
        qos = self._QoS(10)
        assert qos.increment_eventually() == 11
        assert qos.increment_eventually(3) == 14
        # Negative increments are ignored.
        assert qos.increment_eventually(-30) == 14
        assert qos.decrement_eventually(7) == 7
        assert qos.decrement_eventually() == 6

    def test_qos_disabled_increment_decrement(self):
        # A value of 0 never changes in either direction.
        qos = self._QoS(0)
        assert qos.increment_eventually() == 0
        assert qos.increment_eventually(3) == 0
        assert qos.increment_eventually(-30) == 0
        assert qos.decrement_eventually(7) == 0
        assert qos.decrement_eventually() == 0
        assert qos.decrement_eventually(10) == 0

    def test_qos_thread_safe(self):
        from threading import Thread

        qos = self._QoS(10)

        def add():
            for _ in range(1000):
                qos.increment_eventually()

        def sub():
            for _ in range(1000):
                qos.decrement_eventually()

        def threaded(funs):
            threads = [Thread(target=fun) for fun in funs]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        # Two concurrent incrementers: 10 + 2 * 1000.
        threaded([add, add])
        assert qos.value == 2010

        # Balanced increments/decrements must cancel out. Start at 1001,
        # not 1000: if every decrement happened to run before any increment,
        # a start of 1000 would reach the floor. A value of 0 disables
        # increment_eventually (see test_consumer_increment_decrement),
        # which would make the test flaky. With 1001 the value never drops
        # below 1 under any interleaving.
        qos.value = 1001
        threaded([add, sub])
        assert qos.value == 1001

    def test_exceeds_short(self):
        qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
        qos.update()
        assert qos.value == PREFETCH_COUNT_MAX - 1
        qos.increment_eventually()
        assert qos.value == PREFETCH_COUNT_MAX
        qos.increment_eventually()
        assert qos.value == PREFETCH_COUNT_MAX + 1
        qos.decrement_eventually()
        assert qos.value == PREFETCH_COUNT_MAX
        qos.decrement_eventually()
        assert qos.value == PREFETCH_COUNT_MAX - 1

    def test_consumer_increment_decrement(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.update()
        assert qos.value == 10
        mconsumer.qos.assert_called_with(prefetch_count=10)
        qos.decrement_eventually()
        qos.update()
        assert qos.value == 9
        mconsumer.qos.assert_called_with(prefetch_count=9)

        # Without update() the new value is not pushed to the consumer:
        # the last call is unchanged and no extra call was made.
        qos.decrement_eventually()
        assert qos.value == 8
        mconsumer.qos.assert_called_with(prefetch_count=9)
        assert mconsumer.qos.call_count == 2

        # A value of 0 is neither decremented nor incremented.
        qos.value = 0
        qos.decrement_eventually()
        assert qos.value == 0
        qos.increment_eventually()
        assert qos.value == 0

    def test_consumer_decrement_eventually(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.decrement_eventually()
        assert qos.value == 9
        qos.value = 0
        qos.decrement_eventually()
        assert qos.value == 0

    def test_set(self):
        mconsumer = Mock()
        qos = QoS(mconsumer.qos, 10)
        qos.set(12)
        assert qos.prev == 12
        mconsumer.qos.assert_called_once_with(prefetch_count=12)

        # Re-applying the current value must not call the consumer again.
        qos.set(qos.prev)
        assert qos.prev == 12
        mconsumer.qos.assert_called_once_with(prefetch_count=12)
|