rq/tests/test_worker_registration.py

104 lines
4.3 KiB
Python

from unittest.mock import patch
from rq import Queue, Worker
from rq.utils import ceildiv
from rq.worker_registration import (
REDIS_WORKER_KEYS,
WORKERS_BY_QUEUE_KEY,
clean_worker_registry,
get_keys,
register,
unregister,
)
from tests import RQTestCase
class TestWorkerRegistry(RQTestCase):
    """Tests for the helpers in ``rq.worker_registration``."""

    def test_worker_registration(self):
        """Ensure worker.key is correctly set in Redis."""
        foo_queue = Queue(name='foo', connection=self.connection)
        bar_queue = Queue(name='bar', connection=self.connection)
        worker = Worker([foo_queue, bar_queue], connection=self.connection)
        register(worker)
        redis = worker.connection

        # register() adds the worker key to the global worker set and to the
        # per-queue set of every queue the worker listens on.
        self.assertTrue(redis.sismember(worker.redis_workers_keys, worker.key))
        self.assertEqual(Worker.count(connection=redis), 1)
        self.assertTrue(redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key))
        self.assertEqual(Worker.count(queue=foo_queue), 1)
        self.assertTrue(redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key))
        self.assertEqual(Worker.count(queue=bar_queue), 1)

        # unregister() must undo every membership that register() created,
        # and the counts must reflect that.
        unregister(worker)
        self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key))
        self.assertEqual(Worker.count(connection=redis), 0)
        self.assertFalse(redis.sismember(WORKERS_BY_QUEUE_KEY % foo_queue.name, worker.key))
        self.assertEqual(Worker.count(queue=foo_queue), 0)
        self.assertFalse(redis.sismember(WORKERS_BY_QUEUE_KEY % bar_queue.name, worker.key))
        self.assertEqual(Worker.count(queue=bar_queue), 0)

    def test_get_keys_by_queue(self):
        """get_keys(queue) only returns the keys of workers registered for that queue."""
        foo_queue = Queue(name='foo', connection=self.connection)
        bar_queue = Queue(name='bar', connection=self.connection)
        baz_queue = Queue(name='baz', connection=self.connection)
        worker1 = Worker([foo_queue, bar_queue], connection=self.connection)
        worker2 = Worker([foo_queue], connection=self.connection)
        worker3 = Worker([baz_queue], connection=self.connection)
        workers = (worker1, worker2, worker3)

        # Nothing registered yet.
        self.assertEqual(set(), get_keys(foo_queue))

        for worker in workers:
            register(worker)
        try:
            # get_keys(queue) returns worker keys for that queue only.
            self.assertEqual({worker1.key, worker2.key}, get_keys(foo_queue))
            self.assertEqual({worker1.key}, get_keys(bar_queue))
            self.assertEqual({worker3.key}, get_keys(baz_queue))

            # get_keys(connection=...) returns every registered worker key.
            self.assertEqual(
                {worker1.key, worker2.key, worker3.key},
                get_keys(connection=worker1.connection),
            )

            # Calling get_keys with neither a queue nor a connection is an error.
            self.assertRaises(ValueError, get_keys)
        finally:
            # Clean up even when an assertion above fails, so registrations
            # do not leak into later tests.
            for worker in workers:
                unregister(worker)

    def test_clean_registry(self):
        """clean_worker_registry() removes registered worker keys that don't exist in Redis."""
        queue = Queue(name='foo', connection=self.connection)
        worker = Worker([queue], connection=self.connection)
        register(worker)
        redis = worker.connection
        queue_key = WORKERS_BY_QUEUE_KEY % queue.name

        self.assertTrue(redis.sismember(worker.redis_workers_keys, worker.key))
        self.assertTrue(redis.sismember(REDIS_WORKER_KEYS, worker.key))
        self.assertTrue(redis.sismember(queue_key, worker.key))

        clean_worker_registry(queue)

        # The stale key is purged from the global set and the per-queue set.
        self.assertFalse(redis.sismember(worker.redis_workers_keys, worker.key))
        self.assertFalse(redis.sismember(REDIS_WORKER_KEYS, worker.key))
        self.assertFalse(redis.sismember(queue_key, worker.key))

    def test_clean_large_registry(self):
        """
        clean_worker_registry() splits invalid keys into multiple batches for set removal
        so it never sends Redis more keys per command than it can receive.
        """
        worker_count = 11
        max_keys = 6
        # `srem` is issued twice per batch of invalid keys: once for
        # WORKERS_BY_QUEUE_KEY and once for REDIS_WORKER_KEYS.
        srem_calls_per_batch = 2

        queue = Queue(name='foo', connection=self.connection)
        for _ in range(worker_count):
            register(Worker([queue], connection=self.connection))

        # With worker_count keys split into batches of at most max_keys, there
        # are ceildiv(worker_count, max_keys) batches (2 for 11 / 6), hence
        # 2 * srem_calls_per_batch = 4 `srem` calls in total.
        with patch('rq.worker_registration.MAX_KEYS', max_keys), patch('redis.client.Pipeline.srem') as mock:
            clean_worker_registry(queue)
            expected_call_count = ceildiv(worker_count, max_keys) * srem_calls_per_batch
            self.assertEqual(mock.call_count, expected_call_count)