mirror of https://github.com/celery/kombu.git
240 lines
7.7 KiB
Python
240 lines
7.7 KiB
Python
from __future__ import annotations
|
|
|
|
import socket
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
|
|
from kombu.mixins import ConsumerMixin
|
|
from t.mocks import ContextMock
|
|
|
|
|
|
def Message(body, content_type='text/plain', content_encoding='utf-8'):
    """Return a ``Mock`` named ``'Message'`` carrying the given payload.

    :param body: value exposed as the mock's ``body`` attribute.
    :param content_type: value exposed as ``content_type``.
    :param content_encoding: value exposed as ``content_encoding``.
    :returns: the configured mock; every other attribute (``ack`` for
        example) is an ordinary auto-created child mock.
    """
    fake = Mock(name='Message')
    fake.configure_mock(
        body=body,
        content_type=content_type,
        content_encoding=content_encoding,
    )
    return fake
|
|
|
|
|
|
class Cons(ConsumerMixin):
    """``ConsumerMixin`` whose overridable hooks are children of one mock.

    All hook calls are recorded on ``self.calls``, and each hook attribute on
    the instance is the corresponding child of that mock, so tests can assert
    on either ``cons.<hook>`` or ``cons.calls.<hook>``.
    """

    def __init__(self, consumers):
        """Wire the hooks to a fresh recorder mock.

        :param consumers: value that the mocked ``get_consumers`` returns.
        """
        recorder = Mock(name='ConsumerMixin')
        recorder.get_consumers.return_value = consumers
        self.calls = recorder

        # Shadow each hook on the instance with the recorder's child mock.
        for hook_name in (
            'get_consumers',
            'on_connection_revived',
            'on_consume_ready',
            'on_consume_end',
            'on_iteration',
            'on_decode_error',
            'on_connection_error',
        ):
            setattr(self, hook_name, getattr(recorder, hook_name))

        # Calling extra_context(...) hands back the same ContextMock, so the
        # object entered by a ``with`` statement is the one tests inspect.
        context = ContextMock(name='extra_context')
        context.return_value = context
        self.extra_context = context
|
|
|
|
|
|
class test_ConsumerMixin:
    """Drive ``consume()`` and ``Consumer()`` on a mock-backed ``Cons``."""

    def _context(self):
        """Build the fixture triple ``(mixin, consumer_a, consumer_b)``.

        The mixin's ``connection`` is a ``ContextMock`` and its
        ``establish_connection`` is a ``Mock`` returning that same object.
        """
        consumer_a = ContextMock(name='consumerA')
        consumer_b = ContextMock(name='consumerB')
        mixin = Cons([consumer_a, consumer_b])
        connection = ContextMock(name='connection')
        mixin.connection = connection
        mixin.establish_connection = Mock(
            name='est_connection', return_value=connection,
        )
        return mixin, consumer_a, consumer_b

    def test_consume(self):
        """The first step enters every context and drains with timeout=1."""
        mixin, consumer_a, consumer_b = self._context()
        mixin.should_stop = False
        steps = mixin.consume(no_ack=True)

        next(steps)
        for entered in (consumer_a, consumer_b, mixin.extra_context):
            entered.__enter__.assert_called_with()
        mixin.on_consume_ready.assert_called()
        mixin.on_iteration.assert_called_with()
        mixin.connection.drain_events.assert_called_with(timeout=1)

        # Three more steps while should_stop is still False.
        for _ in range(3):
            next(steps)

        # Once the stop flag is set the generator finishes.
        mixin.should_stop = True
        with pytest.raises(StopIteration):
            next(steps)

    def test_consume_drain_raises_socket_error(self):
        """A drain error propagates unless ``should_stop`` is already set."""
        mixin, _consumer_a, _consumer_b = self._context()
        mixin.should_stop = False
        drain_events = mixin.connection.drain_events

        first_run = mixin.consume(no_ack=True)
        drain_events.side_effect = socket.error
        with pytest.raises(socket.error):
            next(first_run)

        def stop_then_fail(*args, **kwargs):
            mixin.should_stop = True
            raise OSError()

        drain_events.side_effect = stop_then_fail
        second_run = mixin.consume(no_ack=True)
        with pytest.raises(StopIteration):
            next(second_run)

    def test_consume_drain_raises_socket_timeout(self):
        """A drain timeout surfaces when ``consume`` gets ``timeout=1``."""
        mixin, _consumer_a, _consumer_b = self._context()
        mixin.should_stop = False
        steps = mixin.consume(no_ack=True, timeout=1)

        def stop_then_time_out(*args, **kwargs):
            mixin.should_stop = True
            raise socket.timeout()

        mixin.connection.drain_events.side_effect = stop_then_time_out
        with pytest.raises(socket.error):
            next(steps)

    def test_Consumer_context(self):
        """``Consumer()`` yields the connection and its default channel."""
        mixin, consumer_a, consumer_b = self._context()

        with mixin.Consumer() as (conn, channel, _consumers):
            assert conn is mixin.connection
            assert channel is conn.default_channel
            mixin.on_connection_revived.assert_called_with()
            mixin.get_consumers.assert_called()

            # First positional argument given to get_consumers: a consumer
            # factory that can be called with no further arguments.
            positional = mixin.get_consumers.call_args[0]
            consumer_factory = positional[0]

            built = consumer_factory()
            assert built.on_decode_error is mixin.on_decode_error
            assert built.channel is conn.default_channel
            for entered in (consumer_a, consumer_b):
                entered.__enter__.assert_called_with()

        # Checked only after the Consumer() context has been left.
        mixin.on_consume_end.assert_called_with(conn, channel)
|
|
|
|
|
|
class test_ConsumerMixin_interface:
    """Check the default behaviour of a bare ``ConsumerMixin`` instance."""

    def setup_method(self):
        """Create a fresh ``ConsumerMixin`` as ``self.c`` before each test.

        Named ``setup_method`` (pytest's xunit-style hook) rather than the
        nose-style ``setup``: pytest deprecated the latter in 7.2 and no
        longer calls it from 8.0, which would leave ``self.c`` undefined.
        """
        self.c = ConsumerMixin()

    # Kept so any harness that still calls the old nose-style name works.
    setup = setup_method

    def test_get_consumers(self):
        """The base ``get_consumers`` raises ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            self.c.get_consumers(Mock(), Mock())

    def test_on_connection_revived(self):
        """The default hook returns ``None``."""
        assert self.c.on_connection_revived() is None

    def test_on_consume_ready(self):
        """The default hook returns ``None``."""
        assert self.c.on_consume_ready(Mock(), Mock(), []) is None

    def test_on_consume_end(self):
        """The default hook returns ``None``."""
        assert self.c.on_consume_end(Mock(), Mock()) is None

    def test_on_iteration(self):
        """The default hook returns ``None``."""
        assert self.c.on_iteration() is None

    def test_on_decode_error(self):
        """A decode error is reported via ``error`` and the message acked."""
        message = Message('foo')
        with patch('kombu.mixins.error') as error:
            self.c.on_decode_error(message, KeyError('foo'))
            error.assert_called()
            message.ack.assert_called_with()

    def test_on_connection_error(self):
        """A connection error is reported via ``warn``."""
        with patch('kombu.mixins.warn') as warn:
            self.c.on_connection_error(KeyError('foo'), 3)
            warn.assert_called()

    def test_extra_context(self):
        """The default ``extra_context`` is usable as a context manager."""
        with self.c.extra_context(Mock(), Mock()):
            pass

    def test_restart_limit(self):
        """``restart_limit`` is truthy on a fresh instance."""
        assert self.c.restart_limit

    def test_connection_errors(self):
        """Both error tuples are read from the connection object."""
        conn = Mock(name='connection')
        self.c.connection = conn
        conn.connection_errors = (KeyError,)
        assert self.c.connection_errors == conn.connection_errors
        conn.channel_errors = (ValueError,)
        assert self.c.channel_errors == conn.channel_errors

    def test__consume_from(self):
        """``_consume_from`` enters every consumer it is given."""
        a = ContextMock(name='A')
        b = ContextMock(name='B')
        a.__enter__ = Mock(name='A.__enter__')
        b.__enter__ = Mock(name='B.__enter__')
        with self.c._consume_from(a, b):
            pass
        a.__enter__.assert_called_with()
        b.__enter__.assert_called_with()

    def test_establish_connection(self):
        """``ensure_connection`` gets the error hook and the retry limit."""
        conn = ContextMock(name='connection')
        # clone() returns the same mock so the assertions below can see
        # whatever establish_connection did with a cloned connection.
        conn.clone.return_value = conn
        self.c.connection = conn
        self.c.connect_max_retries = 3

        with self.c.establish_connection() as conn:
            assert conn
        conn.ensure_connection.assert_called_with(
            self.c.on_connection_error, 3,
        )

    def test_maybe_conn_error(self):
        """A listed connection error raised by the callable is swallowed."""
        conn = ContextMock(name='connection')
        conn.connection_errors = (KeyError,)
        conn.channel_errors = ()

        self.c.connection = conn

        def raises():
            raise KeyError('foo')

        # Passes as long as the KeyError does not escape.
        self.c.maybe_conn_error(raises)

    def test_run(self):
        """``run()`` returns once the stubbed ``consume`` sets the stop flag."""
        conn = ContextMock(name='connection')
        self.c.connection = conn
        conn.connection_errors = (KeyError,)
        conn.channel_errors = ()
        consume = self.c.consume = Mock(name='c.consume')

        def se(*args, **kwargs):
            self.c.should_stop = True
            return [1]

        self.c.should_stop = False
        consume.side_effect = se
        self.c.run()

    def test_run_restart_rate_limited(self):
        """A second pass through ``consume`` is expected to call ``sleep``."""
        conn = ContextMock(name='connection')
        self.c.connection = conn
        conn.connection_errors = (KeyError,)
        conn.channel_errors = ()
        consume = self.c.consume = Mock(name='c.consume')
        # sleep is patched so the test never actually waits.
        with patch('kombu.mixins.sleep') as sleep:
            counter = [0]

            def se(*args, **kwargs):
                # Let the first call through; request a stop on the second.
                if counter[0] >= 1:
                    self.c.should_stop = True
                counter[0] += 1
                return counter

            self.c.should_stop = False
            consume.side_effect = se
            self.c.run()
            sleep.assert_called()

    def test_run_raises(self):
        """A connection error raised by ``consume`` is reported via ``warn``."""
        conn = ContextMock(name='connection')
        self.c.connection = conn
        conn.connection_errors = (KeyError,)
        conn.channel_errors = ()
        consume = self.c.consume = Mock(name='c.consume')

        with patch('kombu.mixins.warn') as warn:
            def se_raises(*args, **kwargs):
                self.c.should_stop = True
                raise KeyError('foo')

            self.c.should_stop = False
            consume.side_effect = se_raises
            self.c.run()
            warn.assert_called()
|