diff --git a/rq/connections.py b/rq/connections.py index 0d39e432..dfb590a4 100644 --- a/rq/connections.py +++ b/rq/connections.py @@ -1,8 +1,8 @@ import warnings from contextlib import contextmanager -from typing import Optional +from typing import Any, Optional, Tuple, Type -from redis import Redis +from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection from .local import LocalStack @@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa yield finally: popped = pop_connection() - assert popped == connection, ( - 'Unexpected Redis connection was popped off the stack. ' - 'Check your Redis connection setup.' - ) + assert ( + popped == connection + ), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.' def push_connection(redis: 'Redis'): @@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis': return connection +def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]: + connection_kwargs = connection.connection_pool.connection_kwargs.copy() + # Redis does not accept parser_class argument which is sometimes present + # on connection_pool kwargs, for example when hiredis is used + connection_kwargs.pop('parser_class', None) + connection_pool_class = connection.connection_pool.connection_class + if issubclass(connection_pool_class, SSLConnection): + connection_kwargs['ssl'] = True + if issubclass(connection_pool_class, UnixDomainSocketConnection): + # The connection keyword arguments are obtained from + # `UnixDomainSocketConnection`, which expects `path`, but passed to + # `redis.client.Redis`, which expects `unix_socket_path`, renaming + # the key is necessary. + # `path` is not left in the dictionary as that keyword argument is + # not expected by `redis.client.Redis` and would raise an exception. + connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path') + + return connection.__class__, connection_pool_class, connection_kwargs + + _connection_stack = LocalStack() __all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection'] - diff --git a/rq/scheduler.py b/rq/scheduler.py index 6e2ab041..069181d9 100644 --- a/rq/scheduler.py +++ b/rq/scheduler.py @@ -6,10 +6,11 @@ import traceback from datetime import datetime from enum import Enum from multiprocessing import Process -from typing import List +from typing import List, Set -from redis import SSLConnection, UnixDomainSocketConnection +from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection +from .connections import parse_connection from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD from .job import Job from .logutils import setup_loghandlers @@ -38,7 +39,7 @@ class RQScheduler: def __init__( self, queues, - connection, + connection: Redis, interval=1, logging_level=logging.INFO, date_format=DEFAULT_LOGGING_DATE_FORMAT, @@ -46,28 +47,10 @@ class RQScheduler: serializer=None, ): self._queue_names = set(parse_names(queues)) - self._acquired_locks = set() - self._scheduled_job_registries = [] + self._acquired_locks: Set[str] = set() + self._scheduled_job_registries: List[ScheduledJobRegistry] = [] self.lock_acquisition_time = None - - # Copy the connection kwargs before mutating them in order to not change the arguments - # used by the current connection pool to create new connections - self._connection_kwargs = connection.connection_pool.connection_kwargs.copy() - # Redis does not accept parser_class argument which is sometimes present - # on connection_pool kwargs, for example when hiredis is used - self._connection_kwargs.pop('parser_class', None) - self._connection_class = connection.__class__ # client - connection_class = connection.connection_pool.connection_class - if issubclass(connection_class, SSLConnection): - self._connection_kwargs['ssl'] = True - if issubclass(connection_class, UnixDomainSocketConnection): - # The connection keyword arguments are obtained from - # `UnixDomainSocketConnection`, which expects `path`, but passed to - # `redis.client.Redis`, which expects `unix_socket_path`, renaming - # the key is necessary. - # `path` is not left in the dictionary as that keyword argument is - # not expected by `redis.client.Redis` and would raise an exception. - self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path') + self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection) self.serializer = resolve_serializer(serializer) self._connection = None @@ -87,7 +70,9 @@ class RQScheduler: def connection(self): if self._connection: return self._connection - self._connection = self._connection_class(**self._connection_kwargs) + self._connection = self._connection_class( + connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs) + ) return self._connection @property diff --git a/tests/test_connection.py b/tests/test_connection.py index 393c20d7..0b64d2be 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,7 @@ -from redis import Redis +from redis import ConnectionPool, Redis, UnixDomainSocketConnection from rq import Connection, Queue +from rq.connections import parse_connection from tests import RQTestCase, find_empty_redis_database from tests.fixtures import do_nothing @@ -28,10 +29,19 @@ class TestConnectionInheritance(RQTestCase): def test_connection_pass_thru(self): """Connection passed through from queues to jobs.""" - q1 = Queue() + q1 = Queue(connection=self.testconn) with Connection(new_connection()): q2 = Queue() job1 = q1.enqueue(do_nothing) job2 = q2.enqueue(do_nothing) self.assertEqual(q1.connection, job1.connection) self.assertEqual(q2.connection, job2.connection) + + def test_parse_connection(self): + """Test parsing `ssl` and UnixDomainSocketConnection""" + _, _, kwargs = parse_connection(Redis(ssl=True)) + self.assertTrue(kwargs['ssl']) + path = '/tmp/redis.sock' + pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path) + _, _, kwargs = parse_connection(Redis(connection_pool=pool)) + self.assertTrue(kwargs['unix_socket_path'], path) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index b159d691..96cde1c4 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1,4 +1,6 @@ import os +import redis + from datetime import datetime, timedelta, timezone from multiprocessing import Process from unittest import mock @@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test from .fixtures import kill_worker, say_hello +class CustomRedisConnection(redis.Connection): + """Custom redis connection with a custom arg, used in test_custom_connection_pool""" + + def __init__(self, *args, custom_arg=None, **kwargs): + self.custom_arg = custom_arg + super().__init__(*args, **kwargs) + + def get_custom_arg(self): + return self.custom_arg + + class TestScheduledJobRegistry(RQTestCase): def test_get_jobs_to_enqueue(self): """Getting job ids to enqueue from ScheduledJobRegistry.""" @@ -425,3 +438,34 @@ class TestQueue(RQTestCase): job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2])) self.assertEqual(job.retries_left, 3) self.assertEqual(job.retry_intervals, [2]) + + def test_custom_connection_pool(self): + """Connection pool customizing. Ensure that we can properly set a + custom connection pool class and pass extra arguments""" + custom_conn = redis.Redis( + connection_pool=redis.ConnectionPool( + connection_class=CustomRedisConnection, + db=4, + custom_arg="foo", + ) + ) + + queue = Queue(connection=custom_conn) + scheduler = RQScheduler([queue], connection=custom_conn) + + scheduler_connection = scheduler.connection.connection_pool.get_connection('info') + + self.assertEqual(scheduler_connection.__class__, CustomRedisConnection) + self.assertEqual(scheduler_connection.get_custom_arg(), "foo") + + def test_no_custom_connection_pool(self): + """Connection pool customizing must not interfere if we're using a standard + connection (non-pooled)""" + standard_conn = redis.Redis(db=5) + + queue = Queue(connection=standard_conn) + scheduler = RQScheduler([queue], connection=standard_conn) + + scheduler_connection = scheduler.connection.connection_pool.get_connection('info') + + self.assertEqual(scheduler_connection.__class__, redis.Connection)