from __future__ import absolute_import, unicode_literals import pytest import socket import types from collections import defaultdict from itertools import count from case import ANY, ContextMock, Mock, call, mock, skip, patch from kombu import Connection, Exchange, Queue, Consumer, Producer from kombu.exceptions import InconsistencyError, VersionMismatch from kombu.five import Empty, Queue as _Queue, bytes_if_py2 from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps class _poll(eventio._select): def register(self, fd, flags): if flags & eventio.READ: self._rfd.add(fd) def poll(self, timeout): events = [] for fd in self._rfd: if fd.data: events.append((fd.fileno(), eventio.READ)) return events eventio.poll = _poll # must import after poller patch, pep8 complains from kombu.transport import redis # noqa class ResponseError(Exception): pass class Client(object): queues = {} sets = defaultdict(set) hashes = defaultdict(dict) shard_hint = None def __init__(self, db=None, port=None, connection_pool=None, **kwargs): self._called = [] self._connection = None self.bgsave_raises_ResponseError = False self.connection = self._sconnection(self) def bgsave(self): self._called.append('BGSAVE') if self.bgsave_raises_ResponseError: raise ResponseError() def delete(self, key): self.queues.pop(key, None) def exists(self, key): return key in self.queues or key in self.sets def hset(self, key, k, v): self.hashes[key][k] = v def hget(self, key, k): return self.hashes[key].get(k) def hdel(self, key, k): self.hashes[key].pop(k, None) def sadd(self, key, member, *args): self.sets[key].add(member) def zadd(self, key, score1, member1, *args): self.sets[key].add(member1) def smembers(self, key): return self.sets.get(key, set()) def ping(self, *args, **kwargs): return True def srem(self, key, *args): self.sets.pop(key, None) zrem = srem def llen(self, key): try: return self.queues[key].qsize() except KeyError: return 0 def lpush(self, key, value): self.queues[key].put_nowait(value) def parse_response(self, connection, type, **options): cmd, queues = self.connection._sock.data.pop() queues = list(queues) assert cmd == type self.connection._sock.data = [] if type == 'BRPOP': timeout = queues.pop() item = self.brpop(queues, timeout) if item: return item raise Empty() def brpop(self, keys, timeout=None): for key in keys: try: item = self.queues[key].get_nowait() except Empty: pass else: return key, item def rpop(self, key): try: return self.queues[key].get_nowait() except (KeyError, Empty): pass def __contains__(self, k): return k in self._called def pipeline(self): return Pipeline(self) def encode(self, value): return str(value) def _new_queue(self, key): self.queues[key] = _Queue() class _sconnection(object): disconnected = False class _socket(object): blocking = True filenos = count(30) def __init__(self, *args): self._fileno = next(self.filenos) self.data = [] def fileno(self): return self._fileno def setblocking(self, blocking): self.blocking = blocking def __init__(self, client): self.client = client self._sock = self._socket() def disconnect(self): self.disconnected = True def send_command(self, cmd, *args): self._sock.data.append((cmd, args)) def info(self): return {'foo': 1} def pubsub(self, *args, **kwargs): connection = self.connection class ConnectionPool(object): def get_connection(self, *args, **kwargs): return connection self.connection_pool = ConnectionPool() return self class Pipeline(object): def __init__(self, client): self.client = client self.stack = [] def __enter__(self): return self def __exit__(self, *exc_info): pass def __getattr__(self, key): if key not in self.__dict__: def _add(*args, **kwargs): self.stack.append((getattr(self.client, key), args, kwargs)) return self return _add return self.__dict__[key] def execute(self): stack = list(self.stack) self.stack[:] = [] return [fun(*args, **kwargs) for fun, args, kwargs in stack] class Channel(redis.Channel): def _get_client(self): return Client def _get_pool(self, async=False): return Mock() def _get_response_error(self): return ResponseError def _new_queue(self, queue, **kwargs): for pri in self.priority_steps: self.client._new_queue(self._q_for_pri(queue, pri)) def pipeline(self): return Pipeline(Client()) class Transport(redis.Transport): Channel = Channel def _get_errors(self): return ((KeyError,), (IndexError,)) @skip.unless_module('redis') class test_Channel: def setup(self): self.connection = self.create_connection() self.channel = self.connection.default_channel def create_connection(self, **kwargs): kwargs.setdefault('transport_options', {'fanout_patterns': True}) return Connection(transport=Transport, **kwargs) def _get_one_delivery_tag(self, n='test_uniq_tag'): with self.create_connection() as conn1: chan = conn1.default_channel chan.exchange_declare(n) chan.queue_declare(n) chan.queue_bind(n, n, n) msg = chan.prepare_message('quick brown fox') chan.basic_publish(msg, n, n) payload = chan._get(n) assert payload pymsg = chan.message_to_python(payload) return pymsg.delivery_tag def test_delivery_tag_is_uuid(self): seen = set() for i in range(100): tag = self._get_one_delivery_tag() assert tag not in seen seen.add(tag) with pytest.raises(ValueError): int(tag) assert len(tag) == 36 def test_disable_ack_emulation(self): conn = Connection(transport=Transport, transport_options={ 'ack_emulation': False, }) chan = conn.channel() assert not chan.ack_emulation assert chan.QoS == virtual.QoS def test_redis_ping_raises(self): pool = Mock(name='pool') pool_at_init = [pool] client = Mock(name='client') class XChannel(Channel): def __init__(self, *args, **kwargs): self._pool = pool_at_init[0] super(XChannel, self).__init__(*args, **kwargs) def _get_client(self): return lambda *_, **__: client class XTransport(Transport): Channel = XChannel conn = Connection(transport=XTransport) client.ping.side_effect = RuntimeError() with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_called_with() pool.disconnect.reset_mock() pool_at_init = [None] with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_not_called() def test_after_fork(self): self.channel._pool = None self.channel._after_fork() pool = self.channel._pool = Mock(name='pool') self.channel._after_fork() pool.disconnect.assert_called_with() def test_next_delivery_tag(self): assert (self.channel._next_delivery_tag() != self.channel._next_delivery_tag()) def test_do_restore_message(self): client = Mock(name='client') pl1 = {'body': 'BODY'} spl1 = dumps(pl1) lookup = self.channel._lookup = Mock(name='_lookup') lookup.return_value = {'george', 'elaine'} self.channel._do_restore_message( pl1, 'ex', 'rkey', client, ) client.rpush.assert_has_calls([ call('george', spl1), call('elaine', spl1), ], any_order=True) client = Mock(name='client') pl2 = {'body': 'BODY2', 'headers': {'x-funny': 1}} headers_after = dict(pl2['headers'], redelivered=True) spl2 = dumps(dict(pl2, headers=headers_after)) self.channel._do_restore_message( pl2, 'ex', 'rkey', client, ) client.rpush.assert_any_call('george', spl2) client.rpush.assert_any_call('elaine', spl2) client.rpush.side_effect = KeyError() with patch('kombu.transport.redis.crit') as crit: self.channel._do_restore_message( pl2, 'ex', 'rkey', client, ) crit.assert_called() def test_restore(self): message = Mock(name='message') with patch('kombu.transport.redis.loads') as loads: loads.return_value = 'M', 'EX', 'RK' client = self.channel._create_client = Mock(name='client') client = client() client.pipeline = ContextMock() restore = self.channel._do_restore_message = Mock( name='_do_restore_message', ) pipe = client.pipeline.return_value pipe_hget = Mock(name='pipe.hget') pipe.hget.return_value = pipe_hget pipe_hget_hdel = Mock(name='pipe.hget.hdel') pipe_hget.hdel.return_value = pipe_hget_hdel result = Mock(name='result') pipe_hget_hdel.execute.return_value = None, None self.channel._restore(message) client.pipeline.assert_called_with() unacked_key = self.channel.unacked_key loads.assert_not_called() tag = message.delivery_tag pipe.hget.assert_called_with(unacked_key, tag) pipe_hget.hdel.assert_called_with(unacked_key, tag) pipe_hget_hdel.execute.assert_called_with() pipe_hget_hdel.execute.return_value = result, None self.channel._restore(message) loads.assert_called_with(result) restore.assert_called_with('M', 'EX', 'RK', client, False) def test_qos_restore_visible(self): client = self.channel._create_client = Mock(name='client') client = client() def pipe(*args, **kwargs): return Pipeline(client) client.pipeline = pipe client.zrevrangebyscore.return_value = [ (1, 10), (2, 20), (3, 30), ] qos = redis.QoS(self.channel) restore = qos.restore_by_tag = Mock(name='restore_by_tag') qos._vrestore_count = 1 qos.restore_visible() client.zrevrangebyscore.assert_not_called() assert qos._vrestore_count == 2 qos._vrestore_count = 0 qos.restore_visible() restore.assert_has_calls([ call(1, client), call(2, client), call(3, client), ]) assert qos._vrestore_count == 1 qos._vrestore_count = 0 restore.reset_mock() client.zrevrangebyscore.return_value = [] qos.restore_visible() restore.assert_not_called() assert qos._vrestore_count == 1 qos._vrestore_count = 0 client.setnx.side_effect = redis.MutexHeld() qos.restore_visible() def test_basic_consume_when_fanout_queue(self): self.channel.exchange_declare(exchange='txconfan', type='fanout') self.channel.queue_declare(queue='txconfanq') self.channel.queue_bind(queue='txconfanq', exchange='txconfan') assert 'txconfanq' in self.channel._fanout_queues self.channel.basic_consume('txconfanq', False, None, 1) assert 'txconfanq' in self.channel.active_fanout_queues assert self.channel._fanout_to_queue.get('txconfan') == 'txconfanq' def test_basic_cancel_unknown_delivery_tag(self): assert self.channel.basic_cancel('txaseqwewq') is None def test_subscribe_no_queues(self): self.channel.subclient = Mock() self.channel.active_fanout_queues.clear() self.channel._subscribe() self.channel.subclient.subscribe.assert_not_called() def test_subscribe(self): self.channel.subclient = Mock() self.channel.active_fanout_queues.add('a') self.channel.active_fanout_queues.add('b') self.channel._fanout_queues.update(a=('a', ''), b=('b', '')) self.channel._subscribe() self.channel.subclient.psubscribe.assert_called() s_args, _ = self.channel.subclient.psubscribe.call_args assert sorted(s_args[0]) == ['/{db}.a', '/{db}.b'] self.channel.subclient.connection._sock = None self.channel._subscribe() self.channel.subclient.connection.connect.assert_called_with() def test_handle_unsubscribe_message(self): s = self.channel.subclient s.subscribed = True self.channel._handle_message(s, ['unsubscribe', 'a', 0]) assert not s.subscribed def test_handle_pmessage_message(self): res = self.channel._handle_message( self.channel.subclient, ['pmessage', 'pattern', 'channel', 'data'], ) assert res == { 'type': 'pmessage', 'pattern': 'pattern', 'channel': 'channel', 'data': 'data', } def test_handle_message(self): res = self.channel._handle_message( self.channel.subclient, ['type', 'channel', 'data'], ) assert res == { 'type': 'type', 'pattern': None, 'channel': 'channel', 'data': 'data', } def test_brpop_start_but_no_queues(self): assert self.channel._brpop_start() is None def test_receive(self): s = self.channel.subclient = Mock() self.channel._fanout_to_queue['a'] = 'b' self.channel.connection._deliver = Mock(name='_deliver') message = { 'body': 'hello', 'properties': { 'delivery_tag': 1, 'delivery_info': {'exchange': 'E', 'routing_key': 'R'}, }, } s.parse_response.return_value = ['message', 'a', dumps(message)] self.channel._receive_one(self.channel.subclient) self.channel.connection._deliver.assert_called_once_with( message, 'b', ) def test_receive_raises_for_connection_error(self): self.channel._in_listen = True s = self.channel.subclient = Mock() s.parse_response.side_effect = KeyError('foo') with pytest.raises(KeyError): self.channel._receive_one(self.channel.subclient) assert not self.channel._in_listen def test_receive_empty(self): s = self.channel.subclient = Mock() s.parse_response.return_value = None assert self.channel._receive_one(self.channel.subclient) is None def test_receive_different_message_Type(self): s = self.channel.subclient = Mock() s.parse_response.return_value = ['message', '/foo/', 0, 'data'] assert self.channel._receive_one(self.channel.subclient) is None def test_brpop_read_raises(self): c = self.channel.client = Mock() c.parse_response.side_effect = KeyError('foo') with pytest.raises(KeyError): self.channel._brpop_read() c.connection.disconnect.assert_called_with() def test_brpop_read_gives_None(self): c = self.channel.client = Mock() c.parse_response.return_value = None with pytest.raises(redis.Empty): self.channel._brpop_read() def test_poll_error(self): c = self.channel.client = Mock() c.parse_response = Mock() self.channel._poll_error('BRPOP') c.parse_response.assert_called_with(c.connection, 'BRPOP') c.parse_response.side_effect = KeyError('foo') with pytest.raises(KeyError): self.channel._poll_error('BRPOP') def test_poll_error_on_type_LISTEN(self): c = self.channel.subclient = Mock() c.parse_response = Mock() self.channel._poll_error('LISTEN') c.parse_response.assert_called_with() c.parse_response.side_effect = KeyError('foo') with pytest.raises(KeyError): self.channel._poll_error('LISTEN') def test_put_fanout(self): self.channel._in_poll = False c = self.channel._create_client = Mock() body = {'hello': 'world'} self.channel._put_fanout('exchange', body, '') c().publish.assert_called_with('/{db}.exchange', dumps(body)) def test_put_priority(self): client = self.channel._create_client = Mock(name='client') msg1 = {'properties': {'priority': 3}} self.channel._put('george', msg1) client().lpush.assert_called_with( self.channel._q_for_pri('george', 3), dumps(msg1), ) msg2 = {'properties': {'priority': 313}} self.channel._put('george', msg2) client().lpush.assert_called_with( self.channel._q_for_pri('george', 9), dumps(msg2), ) msg3 = {'properties': {}} self.channel._put('george', msg3) client().lpush.assert_called_with( self.channel._q_for_pri('george', 0), dumps(msg3), ) def test_delete(self): x = self.channel x._create_client = Mock() x._create_client.return_value = x.client delete = x.client.delete = Mock() srem = x.client.srem = Mock() x._delete('queue', 'exchange', 'routing_key', None) delete.assert_has_calls([ call(x._q_for_pri('queue', pri)) for pri in redis.PRIORITY_STEPS ]) srem.assert_called_with(x.keyprefix_queue % ('exchange',), x.sep.join(['routing_key', '', 'queue'])) def test_has_queue(self): self.channel._create_client = Mock() self.channel._create_client.return_value = self.channel.client exists = self.channel.client.exists = Mock() exists.return_value = True assert self.channel._has_queue('foo') exists.assert_has_calls([ call(self.channel._q_for_pri('foo', pri)) for pri in redis.PRIORITY_STEPS ]) exists.return_value = False assert not self.channel._has_queue('foo') def test_close_when_closed(self): self.channel.closed = True self.channel.close() def test_close_deletes_autodelete_fanout_queues(self): self.channel._fanout_queues = {'foo': ('foo', ''), 'bar': ('bar', '')} self.channel.auto_delete_queues = ['foo'] self.channel.queue_delete = Mock(name='queue_delete') client = self.channel.client self.channel.close() self.channel.queue_delete.assert_has_calls([ call('foo', client=client), ]) def test_close_client_close_raises(self): c = self.channel.client = Mock() connection = c.connection connection.disconnect.side_effect = self.channel.ResponseError() self.channel.close() connection.disconnect.assert_called_with() def test_invalid_database_raises_ValueError(self): with pytest.raises(ValueError): self.channel.connection.client.virtual_host = 'dwqeq' self.channel._connparams() def test_connparams_allows_slash_in_db(self): self.channel.connection.client.virtual_host = '/123' assert self.channel._connparams()['db'] == 123 def test_connparams_db_can_be_int(self): self.channel.connection.client.virtual_host = 124 assert self.channel._connparams()['db'] == 124 def test_new_queue_with_auto_delete(self): redis.Channel._new_queue(self.channel, 'george', auto_delete=False) assert 'george' not in self.channel.auto_delete_queues redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True) assert 'elaine' in self.channel.auto_delete_queues def test_connparams_regular_hostname(self): self.channel.connection.client.hostname = 'george.vandelay.com' assert self.channel._connparams()['host'] == 'george.vandelay.com' def test_rotate_cycle_ValueError(self): cycle = self.channel._queue_cycle cycle.update(['kramer', 'jerry']) cycle.rotate('kramer') assert cycle.items, ['jerry' == 'kramer'] cycle.rotate('elaine') def test_get_client(self): import redis as R KombuRedis = redis.Channel._get_client(self.channel) assert KombuRedis Rv = getattr(R, 'VERSION', None) try: R.VERSION = (2, 4, 0) with pytest.raises(VersionMismatch): redis.Channel._get_client(self.channel) finally: if Rv is not None: R.VERSION = Rv def test_get_response_error(self): from redis.exceptions import ResponseError assert redis.Channel._get_response_error(self.channel) is ResponseError def test_avail_client(self): self.channel._pool = Mock() cc = self.channel._create_client = Mock() with self.channel.conn_or_acquire(): pass cc.assert_called_with() def test_register_with_event_loop(self): transport = self.connection.transport transport.cycle = Mock(name='cycle') transport.cycle.fds = {12: 'LISTEN', 13: 'BRPOP'} conn = Mock(name='conn') loop = Mock(name='loop') redis.Transport.register_with_event_loop(transport, conn, loop) transport.cycle.on_poll_init.assert_called_with(loop.poller) loop.call_repeatedly.assert_called_with( 10, transport.cycle.maybe_restore_messages, ) loop.on_tick.add.assert_called() on_poll_start = loop.on_tick.add.call_args[0][0] on_poll_start() transport.cycle.on_poll_start.assert_called_with() loop.add_reader.assert_has_calls([ call(12, transport.on_readable, 12), call(13, transport.on_readable, 13), ]) def test_transport_on_readable(self): transport = self.connection.transport cycle = transport.cycle = Mock(name='cyle') cycle.on_readable.return_value = None redis.Transport.on_readable(transport, 13) cycle.on_readable.assert_called_with(13) def test_transport_get_errors(self): assert redis.Transport._get_errors(self.connection.transport) def test_transport_driver_version(self): assert redis.Transport.driver_version(self.connection.transport) def test_transport_get_errors_when_InvalidData_used(self): from redis import exceptions class ID(Exception): pass DataError = getattr(exceptions, 'DataError', None) InvalidData = getattr(exceptions, 'InvalidData', None) exceptions.InvalidData = ID exceptions.DataError = None try: errors = redis.Transport._get_errors(self.connection.transport) assert errors assert ID in errors[1] finally: if DataError is not None: exceptions.DataError = DataError if InvalidData is not None: exceptions.InvalidData = InvalidData def test_empty_queues_key(self): channel = self.channel channel._in_poll = False key = channel.keyprefix_queue % 'celery' # Everything is fine, there is a list of queues. channel.client.sadd(key, 'celery\x06\x16\x06\x16celery') assert channel.get_table('celery') == [ ('celery', '', 'celery'), ] # ... then for some reason, the _kombu.binding.celery key gets lost channel.client.srem(key) # which raises a channel error so that the consumer/publisher # can recover by redeclaring the required entities. with pytest.raises(InconsistencyError): self.channel.get_table('celery') def test_socket_connection(self): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis+socket:///tmp/redis.sock') as conn: connparams = conn.default_channel._connparams() assert issubclass( connparams['connection_class'], redis.redis.UnixDomainSocketConnection, ) assert connparams['path'] == '/tmp/redis.sock' def test_ssl_argument__dict(self): with patch('kombu.transport.redis.Channel._create_client'): # Expected format for redis-py's SSLConnection class ssl_params = { 'ssl_cert_reqs': 2, 'ssl_ca_certs': '/foo/ca.pem', 'ssl_certfile': '/foo/cert.crt', 'ssl_keyfile': '/foo/pkey.key' } with Connection('redis://', ssl=ssl_params) as conn: params = conn.default_channel._connparams() assert params['ssl_cert_reqs'] == ssl_params['ssl_cert_reqs'] assert params['ssl_ca_certs'] == ssl_params['ssl_ca_certs'] assert params['ssl_certfile'] == ssl_params['ssl_certfile'] assert params['ssl_keyfile'] == ssl_params['ssl_keyfile'] assert params.get('ssl') is None def test_ssl_connection(self): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis://', ssl={'ssl_cert_reqs': 2}) as conn: connparams = conn.default_channel._connparams() assert issubclass( connparams['connection_class'], redis.redis.SSLConnection, ) @skip.unless_module('redis') class test_Redis: def setup(self): self.connection = Connection(transport=Transport) self.exchange = Exchange('test_Redis', type='direct') self.queue = Queue('test_Redis', self.exchange, 'test_Redis') def teardown(self): self.connection.close() def test_publish__get(self): channel = self.connection.channel() producer = Producer(channel, self.exchange, routing_key='test_Redis') self.queue(channel).declare() producer.publish({'hello': 'world'}) assert self.queue(channel).get().payload == {'hello': 'world'} assert self.queue(channel).get() is None assert self.queue(channel).get() is None assert self.queue(channel).get() is None def test_publish__consume(self): connection = Connection(transport=Transport) channel = connection.channel() producer = Producer(channel, self.exchange, routing_key='test_Redis') consumer = Consumer(channel, queues=[self.queue]) producer.publish({'hello2': 'world2'}) _received = [] def callback(message_data, message): _received.append(message_data) message.ack() consumer.register_callback(callback) consumer.consume() assert channel in channel.connection.cycle._channels try: connection.drain_events(timeout=1) assert _received with pytest.raises(socket.timeout): connection.drain_events(timeout=0.01) finally: channel.close() def test_purge(self): channel = self.connection.channel() producer = Producer(channel, self.exchange, routing_key='test_Redis') self.queue(channel).declare() for i in range(10): producer.publish({'hello': 'world-%s' % (i,)}) assert channel._size('test_Redis') == 10 assert self.queue(channel).purge() == 10 channel.close() def test_db_values(self): Connection(virtual_host=1, transport=Transport).channel() Connection(virtual_host='1', transport=Transport).channel() Connection(virtual_host='/1', transport=Transport).channel() with pytest.raises(Exception): Connection('redis:///foo').channel() def test_db_port(self): c1 = Connection(port=None, transport=Transport).channel() c1.close() c2 = Connection(port=9999, transport=Transport).channel() c2.close() def test_close_poller_not_active(self): c = Connection(transport=Transport).channel() cycle = c.connection.cycle c.client.connection c.close() assert c not in cycle._channels def test_close_ResponseError(self): c = Connection(transport=Transport).channel() c.client.bgsave_raises_ResponseError = True c.close() def test_close_disconnects(self): c = Connection(transport=Transport).channel() conn1 = c.client.connection conn2 = c.subclient.connection c.close() assert conn1.disconnected assert conn2.disconnected def test_get__Empty(self): channel = self.connection.channel() with pytest.raises(Empty): channel._get('does-not-exist') channel.close() def test_get_client(self): with mock.module_exists(*_redis_modules()): conn = Connection(transport=Transport) chan = conn.channel() assert chan.Client assert chan.ResponseError assert conn.transport.connection_errors assert conn.transport.channel_errors def test_check_at_least_we_try_to_connect_and_fail(self): import redis connection = Connection('redis://localhost:65534/') with pytest.raises(redis.exceptions.ConnectionError): chan = connection.channel() chan._size('some_queue') def _redis_modules(): class ConnectionError(Exception): pass class AuthenticationError(Exception): pass class InvalidData(Exception): pass class InvalidResponse(Exception): pass class ResponseError(Exception): pass exceptions = types.ModuleType(bytes_if_py2('redis.exceptions')) exceptions.ConnectionError = ConnectionError exceptions.AuthenticationError = AuthenticationError exceptions.InvalidData = InvalidData exceptions.InvalidResponse = InvalidResponse exceptions.ResponseError = ResponseError class Redis(object): pass myredis = types.ModuleType(bytes_if_py2('redis')) myredis.exceptions = exceptions myredis.Redis = Redis return myredis, exceptions @skip.unless_module('redis') class test_MultiChannelPoller: def setup(self): self.Poller = redis.MultiChannelPoller def test_on_poll_start(self): p = self.Poller() p._channels = [] p.on_poll_start() p._register_BRPOP = Mock(name='_register_BRPOP') p._register_LISTEN = Mock(name='_register_LISTEN') chan1 = Mock(name='chan1') p._channels = [chan1] chan1.active_queues = [] chan1.active_fanout_queues = [] p.on_poll_start() chan1.active_queues = ['q1'] chan1.active_fanout_queues = ['q2'] chan1.qos.can_consume.return_value = False p.on_poll_start() p._register_LISTEN.assert_called_with(chan1) p._register_BRPOP.assert_not_called() chan1.qos.can_consume.return_value = True p._register_LISTEN.reset_mock() p.on_poll_start() p._register_BRPOP.assert_called_with(chan1) p._register_LISTEN.assert_called_with(chan1) def test_on_poll_init(self): p = self.Poller() chan1 = Mock(name='chan1') p._channels = [] poller = Mock(name='poller') p.on_poll_init(poller) assert p.poller is poller p._channels = [chan1] p.on_poll_init(poller) chan1.qos.restore_visible.assert_called_with( num=chan1.unacked_restore_limit, ) def test_handle_event(self): p = self.Poller() chan = Mock(name='chan') p._fd_to_chan[13] = chan, 'BRPOP' chan.handlers = {'BRPOP': Mock(name='BRPOP')} chan.qos.can_consume.return_value = False p.handle_event(13, redis.READ) chan.handlers['BRPOP'].assert_not_called() chan.qos.can_consume.return_value = True p.handle_event(13, redis.READ) chan.handlers['BRPOP'].assert_called_with() p.handle_event(13, redis.ERR) chan._poll_error.assert_called_with('BRPOP') p.handle_event(13, ~(redis.READ | redis.ERR)) def test_fds(self): p = self.Poller() p._fd_to_chan = {1: 2} assert p.fds == p._fd_to_chan def test_close_unregisters_fds(self): p = self.Poller() poller = p.poller = Mock() p._chan_to_sock.update({1: 1, 2: 2, 3: 3}) p.close() assert poller.unregister.call_count == 3 u_args = poller.unregister.call_args_list assert sorted(u_args) == [ ((1,), {}), ((2,), {}), ((3,), {}), ] def test_close_when_unregister_raises_KeyError(self): p = self.Poller() p.poller = Mock() p._chan_to_sock.update({1: 1}) p.poller.unregister.side_effect = KeyError(1) p.close() def test_close_resets_state(self): p = self.Poller() p.poller = Mock() p._channels = Mock() p._fd_to_chan = Mock() p._chan_to_sock = Mock() p._chan_to_sock.itervalues.return_value = [] p._chan_to_sock.values.return_value = [] # py3k p.close() p._channels.clear.assert_called_with() p._fd_to_chan.clear.assert_called_with() p._chan_to_sock.clear.assert_called_with() def test_register_when_registered_reregisters(self): p = self.Poller() p.poller = Mock() channel, client, type = Mock(), Mock(), Mock() sock = client.connection._sock = Mock() sock.fileno.return_value = 10 p._chan_to_sock = {(channel, client, type): 6} p._register(channel, client, type) p.poller.unregister.assert_called_with(6) assert p._fd_to_chan[10] == (channel, type) assert p._chan_to_sock[(channel, client, type)] == sock p.poller.register.assert_called_with(sock, p.eventflags) # when client not connected yet client.connection._sock = None def after_connected(): client.connection._sock = Mock() client.connection.connect.side_effect = after_connected p._register(channel, client, type) client.connection.connect.assert_called_with() def test_register_BRPOP(self): p = self.Poller() channel = Mock() channel.client.connection._sock = None p._register = Mock() channel._in_poll = False p._register_BRPOP(channel) assert channel._brpop_start.call_count == 1 assert p._register.call_count == 1 channel.client.connection._sock = Mock() p._chan_to_sock[(channel, channel.client, 'BRPOP')] = True channel._in_poll = True p._register_BRPOP(channel) assert channel._brpop_start.call_count == 1 assert p._register.call_count == 1 def test_register_LISTEN(self): p = self.Poller() channel = Mock() channel.subclient.connection._sock = None channel._in_listen = False p._register = Mock() p._register_LISTEN(channel) p._register.assert_called_with(channel, channel.subclient, 'LISTEN') assert p._register.call_count == 1 assert channel._subscribe.call_count == 1 channel._in_listen = True p._chan_to_sock[(channel, channel.subclient, 'LISTEN')] = 3 channel.subclient.connection._sock = Mock() p._register_LISTEN(channel) assert p._register.call_count == 1 assert channel._subscribe.call_count == 1 def create_get(self, events=None, queues=None, fanouts=None): _pr = [] if events is None else events _aq = [] if queues is None else queues _af = [] if fanouts is None else fanouts p = self.Poller() p.poller = Mock() p.poller.poll.return_value = _pr p._register_BRPOP = Mock() p._register_LISTEN = Mock() channel = Mock() p._channels = [channel] channel.active_queues = _aq channel.active_fanout_queues = _af return p, channel def test_get_no_actions(self): p, channel = self.create_get() with pytest.raises(redis.Empty): p.get(Mock()) def test_qos_reject(self): p, channel = self.create_get() qos = redis.QoS(channel) qos.ack = Mock(name='Qos.ack') qos.reject(1234) qos.ack.assert_called_with(1234) def test_get_brpop_qos_allow(self): p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = True with pytest.raises(redis.Empty): p.get(Mock()) p._register_BRPOP.assert_called_with(channel) def test_get_brpop_qos_disallow(self): p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = False with pytest.raises(redis.Empty): p.get(Mock()) p._register_BRPOP.assert_not_called() def test_get_listen(self): p, channel = self.create_get(fanouts=['f_queue']) with pytest.raises(redis.Empty): p.get(Mock()) p._register_LISTEN.assert_called_with(channel) def test_get_receives_ERR(self): p, channel = self.create_get(events=[(1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') with pytest.raises(redis.Empty): p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') def test_get_receives_multiple(self): p, channel = self.create_get(events=[(1, eventio.ERR), (1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') with pytest.raises(redis.Empty): p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') @skip.unless_module('redis') class test_Mutex: def test_mutex(self, lock_id='xxx'): client = Mock(name='client') with patch('kombu.transport.redis.uuid') as uuid: # Won uuid.return_value = lock_id client.setnx.return_value = True client.pipeline = ContextMock() pipe = client.pipeline.return_value pipe.get.return_value = lock_id held = False with redis.Mutex(client, 'foo1', 100): held = True assert held client.setnx.assert_called_with('foo1', lock_id) pipe.get.return_value = 'yyy' held = False with redis.Mutex(client, 'foo1', 100): held = True assert held # Did not win client.expire.reset_mock() pipe.get.return_value = lock_id client.setnx.return_value = False with pytest.raises(redis.MutexHeld): held = False with redis.Mutex(client, 'foo1', '100'): held = True assert not held client.ttl.return_value = 0 with pytest.raises(redis.MutexHeld): held = False with redis.Mutex(client, 'foo1', '100'): held = True assert not held client.expire.assert_called() # Wins but raises WatchError (and that is ignored) client.setnx.return_value = True pipe.watch.side_effect = redis.redis.WatchError() held = False with redis.Mutex(client, 'foo1', 100): held = True assert held @skip.unless_module('redis.sentinel') class test_RedisSentinel: def test_method_called(self): from kombu.transport.redis import SentinelChannel with patch.object(SentinelChannel, '_sentinel_managed_pool') as p: connection = Connection( 'sentinel://localhost:65534/', transport_options={ 'master_name': 'not_important', }, ) connection.channel() p.assert_called() def test_getting_master_from_sentinel(self): with patch('redis.sentinel.Sentinel') as patched: connection = Connection( 'sentinel://localhost:65534/', transport_options={ 'master_name': 'not_important', }, ) connection.channel() assert patched master_for = patched.return_value.master_for master_for.assert_called() master_for.assert_called_with('not_important', ANY) master_for().connection_pool.get_connection.assert_called() def test_can_create_connection(self): from redis.exceptions import ConnectionError connection = Connection( 'sentinel://localhost:65534/', transport_options={ 'master_name': 'not_important', }, ) with pytest.raises(ConnectionError): connection.channel()