Merge branch 'master' of github.com:rq/rq into worker-pool

This commit is contained in:
Selwin Ong 2023-04-25 19:45:20 +07:00
commit bf7d0d74e0
4 changed files with 91 additions and 34 deletions

View File

@ -1,8 +1,8 @@
import warnings
from contextlib import contextmanager
from typing import Optional
from typing import Any, Optional, Tuple, Type
from redis import Redis
from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection
from .local import LocalStack
@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa
yield
finally:
popped = pop_connection()
assert popped == connection, (
'Unexpected Redis connection was popped off the stack. '
'Check your Redis connection setup.'
)
assert (
popped == connection
), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.'
def push_connection(redis: 'Redis'):
@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
return connection
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
connection_kwargs.pop('parser_class', None)
connection_pool_class = connection.connection_pool.connection_class
if issubclass(connection_pool_class, SSLConnection):
connection_kwargs['ssl'] = True
if issubclass(connection_pool_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
return connection.__class__, connection_pool_class, connection_kwargs
_connection_stack = LocalStack()
__all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection']

View File

@ -6,10 +6,11 @@ import traceback
from datetime import datetime
from enum import Enum
from multiprocessing import Process
from typing import List
from typing import List, Set
from redis import SSLConnection, UnixDomainSocketConnection
from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
from .connections import parse_connection
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD
from .job import Job
from .logutils import setup_loghandlers
@ -38,7 +39,7 @@ class RQScheduler:
def __init__(
self,
queues,
connection,
connection: Redis,
interval=1,
logging_level=logging.INFO,
date_format=DEFAULT_LOGGING_DATE_FORMAT,
@ -46,28 +47,10 @@ class RQScheduler:
serializer=None,
):
self._queue_names = set(parse_names(queues))
self._acquired_locks = set()
self._scheduled_job_registries = []
self._acquired_locks: Set[str] = set()
self._scheduled_job_registries: List[ScheduledJobRegistry] = []
self.lock_acquisition_time = None
# Copy the connection kwargs before mutating them in order to not change the arguments
# used by the current connection pool to create new connections
self._connection_kwargs = connection.connection_pool.connection_kwargs.copy()
# Redis does not accept parser_class argument which is sometimes present
# on connection_pool kwargs, for example when hiredis is used
self._connection_kwargs.pop('parser_class', None)
self._connection_class = connection.__class__ # client
connection_class = connection.connection_pool.connection_class
if issubclass(connection_class, SSLConnection):
self._connection_kwargs['ssl'] = True
if issubclass(connection_class, UnixDomainSocketConnection):
# The connection keyword arguments are obtained from
# `UnixDomainSocketConnection`, which expects `path`, but passed to
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
# the key is necessary.
# `path` is not left in the dictionary as that keyword argument is
# not expected by `redis.client.Redis` and would raise an exception.
self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path')
self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection)
self.serializer = resolve_serializer(serializer)
self._connection = None
@ -87,7 +70,9 @@ class RQScheduler:
def connection(self):
if self._connection:
return self._connection
self._connection = self._connection_class(**self._connection_kwargs)
self._connection = self._connection_class(
connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs)
)
return self._connection
@property

View File

@ -1,6 +1,7 @@
from redis import Redis
from redis import ConnectionPool, Redis, UnixDomainSocketConnection
from rq import Connection, Queue
from rq.connections import parse_connection
from tests import RQTestCase, find_empty_redis_database
from tests.fixtures import do_nothing
@ -28,10 +29,19 @@ class TestConnectionInheritance(RQTestCase):
def test_connection_pass_thru(self):
"""Connection passed through from queues to jobs."""
q1 = Queue()
q1 = Queue(connection=self.testconn)
with Connection(new_connection()):
q2 = Queue()
job1 = q1.enqueue(do_nothing)
job2 = q2.enqueue(do_nothing)
self.assertEqual(q1.connection, job1.connection)
self.assertEqual(q2.connection, job2.connection)
def test_parse_connection(self):
"""Test parsing `ssl` and UnixDomainSocketConnection"""
_, _, kwargs = parse_connection(Redis(ssl=True))
self.assertTrue(kwargs['ssl'])
path = '/tmp/redis.sock'
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
_, _, kwargs = parse_connection(Redis(connection_pool=pool))
self.assertTrue(kwargs['unix_socket_path'], path)

View File

@ -1,4 +1,6 @@
import os
import redis
from datetime import datetime, timedelta, timezone
from multiprocessing import Process
from unittest import mock
@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test
from .fixtures import kill_worker, say_hello
class CustomRedisConnection(redis.Connection):
"""Custom redis connection with a custom arg, used in test_custom_connection_pool"""
def __init__(self, *args, custom_arg=None, **kwargs):
self.custom_arg = custom_arg
super().__init__(*args, **kwargs)
def get_custom_arg(self):
return self.custom_arg
class TestScheduledJobRegistry(RQTestCase):
def test_get_jobs_to_enqueue(self):
"""Getting job ids to enqueue from ScheduledJobRegistry."""
@ -425,3 +438,34 @@ class TestQueue(RQTestCase):
job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2]))
self.assertEqual(job.retries_left, 3)
self.assertEqual(job.retry_intervals, [2])
def test_custom_connection_pool(self):
"""Connection pool customizing. Ensure that we can properly set a
custom connection pool class and pass extra arguments"""
custom_conn = redis.Redis(
connection_pool=redis.ConnectionPool(
connection_class=CustomRedisConnection,
db=4,
custom_arg="foo",
)
)
queue = Queue(connection=custom_conn)
scheduler = RQScheduler([queue], connection=custom_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, CustomRedisConnection)
self.assertEqual(scheduler_connection.get_custom_arg(), "foo")
def test_no_custom_connection_pool(self):
"""Connection pool customizing must not interfere if we're using a standard
connection (non-pooled)"""
standard_conn = redis.Redis(db=5)
queue = Queue(connection=standard_conn)
scheduler = RQScheduler([queue], connection=standard_conn)
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
self.assertEqual(scheduler_connection.__class__, redis.Connection)