kombu/t/unit/test_mixins.py

240 lines
7.7 KiB
Python

from __future__ import annotations
import socket
from unittest.mock import Mock, patch
import pytest
from kombu.mixins import ConsumerMixin
from t.mocks import ContextMock
def Message(body, content_type='text/plain', content_encoding='utf-8'):
m = Mock(name='Message')
m.body = body
m.content_type = content_type
m.content_encoding = content_encoding
return m
class Cons(ConsumerMixin):
def __init__(self, consumers):
self.calls = Mock(name='ConsumerMixin')
self.calls.get_consumers.return_value = consumers
self.get_consumers = self.calls.get_consumers
self.on_connection_revived = self.calls.on_connection_revived
self.on_consume_ready = self.calls.on_consume_ready
self.on_consume_end = self.calls.on_consume_end
self.on_iteration = self.calls.on_iteration
self.on_decode_error = self.calls.on_decode_error
self.on_connection_error = self.calls.on_connection_error
self.extra_context = ContextMock(name='extra_context')
self.extra_context.return_value = self.extra_context
class test_ConsumerMixin:
def _context(self):
Acons = ContextMock(name='consumerA')
Bcons = ContextMock(name='consumerB')
c = Cons([Acons, Bcons])
_conn = c.connection = ContextMock(name='connection')
est = c.establish_connection = Mock(name='est_connection')
est.return_value = _conn
return c, Acons, Bcons
def test_consume(self):
c, Acons, Bcons = self._context()
c.should_stop = False
it = c.consume(no_ack=True)
next(it)
Acons.__enter__.assert_called_with()
Bcons.__enter__.assert_called_with()
c.extra_context.__enter__.assert_called_with()
c.on_consume_ready.assert_called()
c.on_iteration.assert_called_with()
c.connection.drain_events.assert_called_with(timeout=1)
next(it)
next(it)
next(it)
c.should_stop = True
with pytest.raises(StopIteration):
next(it)
def test_consume_drain_raises_socket_error(self):
c, Acons, Bcons = self._context()
c.should_stop = False
it = c.consume(no_ack=True)
c.connection.drain_events.side_effect = socket.error
with pytest.raises(socket.error):
next(it)
def se2(*args, **kwargs):
c.should_stop = True
raise OSError()
c.connection.drain_events.side_effect = se2
it = c.consume(no_ack=True)
with pytest.raises(StopIteration):
next(it)
def test_consume_drain_raises_socket_timeout(self):
c, Acons, Bcons = self._context()
c.should_stop = False
it = c.consume(no_ack=True, timeout=1)
def se(*args, **kwargs):
c.should_stop = True
raise socket.timeout()
c.connection.drain_events.side_effect = se
with pytest.raises(socket.error):
next(it)
def test_Consumer_context(self):
c, Acons, Bcons = self._context()
with c.Consumer() as (conn, channel, consumer):
assert conn is c.connection
assert channel is conn.default_channel
c.on_connection_revived.assert_called_with()
c.get_consumers.assert_called()
cls = c.get_consumers.call_args[0][0]
subcons = cls()
assert subcons.on_decode_error is c.on_decode_error
assert subcons.channel is conn.default_channel
Acons.__enter__.assert_called_with()
Bcons.__enter__.assert_called_with()
c.on_consume_end.assert_called_with(conn, channel)
class test_ConsumerMixin_interface:
def setup(self):
self.c = ConsumerMixin()
def test_get_consumers(self):
with pytest.raises(NotImplementedError):
self.c.get_consumers(Mock(), Mock())
def test_on_connection_revived(self):
assert self.c.on_connection_revived() is None
def test_on_consume_ready(self):
assert self.c.on_consume_ready(Mock(), Mock(), []) is None
def test_on_consume_end(self):
assert self.c.on_consume_end(Mock(), Mock()) is None
def test_on_iteration(self):
assert self.c.on_iteration() is None
def test_on_decode_error(self):
message = Message('foo')
with patch('kombu.mixins.error') as error:
self.c.on_decode_error(message, KeyError('foo'))
error.assert_called()
message.ack.assert_called_with()
def test_on_connection_error(self):
with patch('kombu.mixins.warn') as warn:
self.c.on_connection_error(KeyError('foo'), 3)
warn.assert_called()
def test_extra_context(self):
with self.c.extra_context(Mock(), Mock()):
pass
def test_restart_limit(self):
assert self.c.restart_limit
def test_connection_errors(self):
conn = Mock(name='connection')
self.c.connection = conn
conn.connection_errors = (KeyError,)
assert self.c.connection_errors == conn.connection_errors
conn.channel_errors = (ValueError,)
assert self.c.channel_errors == conn.channel_errors
def test__consume_from(self):
a = ContextMock(name='A')
b = ContextMock(name='B')
a.__enter__ = Mock(name='A.__enter__')
b.__enter__ = Mock(name='B.__enter__')
with self.c._consume_from(a, b):
pass
a.__enter__.assert_called_with()
b.__enter__.assert_called_with()
def test_establish_connection(self):
conn = ContextMock(name='connection')
conn.clone.return_value = conn
self.c.connection = conn
self.c.connect_max_retries = 3
with self.c.establish_connection() as conn:
assert conn
conn.ensure_connection.assert_called_with(
self.c.on_connection_error, 3,
)
def test_maybe_conn_error(self):
conn = ContextMock(name='connection')
conn.connection_errors = (KeyError,)
conn.channel_errors = ()
self.c.connection = conn
def raises():
raise KeyError('foo')
self.c.maybe_conn_error(raises)
def test_run(self):
conn = ContextMock(name='connection')
self.c.connection = conn
conn.connection_errors = (KeyError,)
conn.channel_errors = ()
consume = self.c.consume = Mock(name='c.consume')
def se(*args, **kwargs):
self.c.should_stop = True
return [1]
self.c.should_stop = False
consume.side_effect = se
self.c.run()
def test_run_restart_rate_limited(self):
conn = ContextMock(name='connection')
self.c.connection = conn
conn.connection_errors = (KeyError,)
conn.channel_errors = ()
consume = self.c.consume = Mock(name='c.consume')
with patch('kombu.mixins.sleep') as sleep:
counter = [0]
def se(*args, **kwargs):
if counter[0] >= 1:
self.c.should_stop = True
counter[0] += 1
return counter
self.c.should_stop = False
consume.side_effect = se
self.c.run()
sleep.assert_called()
def test_run_raises(self):
conn = ContextMock(name='connection')
self.c.connection = conn
conn.connection_errors = (KeyError,)
conn.channel_errors = ()
consume = self.c.consume = Mock(name='c.consume')
with patch('kombu.mixins.warn') as warn:
def se_raises(*args, **kwargs):
self.c.should_stop = True
raise KeyError('foo')
self.c.should_stop = False
consume.side_effect = se_raises
self.c.run()
warn.assert_called()