mirror of https://github.com/rq/rq.git
Merge branch 'master' of github.com:rq/rq into worker-pool
This commit is contained in:
commit
bf7d0d74e0
|
@ -1,8 +1,8 @@
|
|||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
|
||||
from redis import Redis
|
||||
from redis import Connection as RedisConnection, Redis, SSLConnection, UnixDomainSocketConnection
|
||||
|
||||
from .local import LocalStack
|
||||
|
||||
|
@ -42,10 +42,9 @@ def Connection(connection: Optional['Redis'] = None): # noqa
|
|||
yield
|
||||
finally:
|
||||
popped = pop_connection()
|
||||
assert popped == connection, (
|
||||
'Unexpected Redis connection was popped off the stack. '
|
||||
'Check your Redis connection setup.'
|
||||
)
|
||||
assert (
|
||||
popped == connection
|
||||
), 'Unexpected Redis connection was popped off the stack. Check your Redis connection setup.'
|
||||
|
||||
|
||||
def push_connection(redis: 'Redis'):
|
||||
|
@ -118,8 +117,27 @@ def resolve_connection(connection: Optional['Redis'] = None) -> 'Redis':
|
|||
return connection
|
||||
|
||||
|
||||
def parse_connection(connection: Redis) -> Tuple[Type[Redis], Type[RedisConnection], dict]:
|
||||
connection_kwargs = connection.connection_pool.connection_kwargs.copy()
|
||||
# Redis does not accept parser_class argument which is sometimes present
|
||||
# on connection_pool kwargs, for example when hiredis is used
|
||||
connection_kwargs.pop('parser_class', None)
|
||||
connection_pool_class = connection.connection_pool.connection_class
|
||||
if issubclass(connection_pool_class, SSLConnection):
|
||||
connection_kwargs['ssl'] = True
|
||||
if issubclass(connection_pool_class, UnixDomainSocketConnection):
|
||||
# The connection keyword arguments are obtained from
|
||||
# `UnixDomainSocketConnection`, which expects `path`, but passed to
|
||||
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
|
||||
# the key is necessary.
|
||||
# `path` is not left in the dictionary as that keyword argument is
|
||||
# not expected by `redis.client.Redis` and would raise an exception.
|
||||
connection_kwargs['unix_socket_path'] = connection_kwargs.pop('path')
|
||||
|
||||
return connection.__class__, connection_pool_class, connection_kwargs
|
||||
|
||||
|
||||
_connection_stack = LocalStack()
|
||||
|
||||
|
||||
__all__ = ['Connection', 'get_current_connection', 'push_connection', 'pop_connection']
|
||||
|
||||
|
|
|
@ -6,10 +6,11 @@ import traceback
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from multiprocessing import Process
|
||||
from typing import List
|
||||
from typing import List, Set
|
||||
|
||||
from redis import SSLConnection, UnixDomainSocketConnection
|
||||
from redis import ConnectionPool, Redis, SSLConnection, UnixDomainSocketConnection
|
||||
|
||||
from .connections import parse_connection
|
||||
from .defaults import DEFAULT_LOGGING_DATE_FORMAT, DEFAULT_LOGGING_FORMAT, DEFAULT_SCHEDULER_FALLBACK_PERIOD
|
||||
from .job import Job
|
||||
from .logutils import setup_loghandlers
|
||||
|
@ -38,7 +39,7 @@ class RQScheduler:
|
|||
def __init__(
|
||||
self,
|
||||
queues,
|
||||
connection,
|
||||
connection: Redis,
|
||||
interval=1,
|
||||
logging_level=logging.INFO,
|
||||
date_format=DEFAULT_LOGGING_DATE_FORMAT,
|
||||
|
@ -46,28 +47,10 @@ class RQScheduler:
|
|||
serializer=None,
|
||||
):
|
||||
self._queue_names = set(parse_names(queues))
|
||||
self._acquired_locks = set()
|
||||
self._scheduled_job_registries = []
|
||||
self._acquired_locks: Set[str] = set()
|
||||
self._scheduled_job_registries: List[ScheduledJobRegistry] = []
|
||||
self.lock_acquisition_time = None
|
||||
|
||||
# Copy the connection kwargs before mutating them in order to not change the arguments
|
||||
# used by the current connection pool to create new connections
|
||||
self._connection_kwargs = connection.connection_pool.connection_kwargs.copy()
|
||||
# Redis does not accept parser_class argument which is sometimes present
|
||||
# on connection_pool kwargs, for example when hiredis is used
|
||||
self._connection_kwargs.pop('parser_class', None)
|
||||
self._connection_class = connection.__class__ # client
|
||||
connection_class = connection.connection_pool.connection_class
|
||||
if issubclass(connection_class, SSLConnection):
|
||||
self._connection_kwargs['ssl'] = True
|
||||
if issubclass(connection_class, UnixDomainSocketConnection):
|
||||
# The connection keyword arguments are obtained from
|
||||
# `UnixDomainSocketConnection`, which expects `path`, but passed to
|
||||
# `redis.client.Redis`, which expects `unix_socket_path`, renaming
|
||||
# the key is necessary.
|
||||
# `path` is not left in the dictionary as that keyword argument is
|
||||
# not expected by `redis.client.Redis` and would raise an exception.
|
||||
self._connection_kwargs['unix_socket_path'] = self._connection_kwargs.pop('path')
|
||||
self._connection_class, self._pool_class, self._connection_kwargs = parse_connection(connection)
|
||||
self.serializer = resolve_serializer(serializer)
|
||||
|
||||
self._connection = None
|
||||
|
@ -87,7 +70,9 @@ class RQScheduler:
|
|||
def connection(self):
|
||||
if self._connection:
|
||||
return self._connection
|
||||
self._connection = self._connection_class(**self._connection_kwargs)
|
||||
self._connection = self._connection_class(
|
||||
connection_pool=ConnectionPool(connection_class=self._pool_class, **self._connection_kwargs)
|
||||
)
|
||||
return self._connection
|
||||
|
||||
@property
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from redis import Redis
|
||||
from redis import ConnectionPool, Redis, UnixDomainSocketConnection
|
||||
|
||||
from rq import Connection, Queue
|
||||
from rq.connections import parse_connection
|
||||
from tests import RQTestCase, find_empty_redis_database
|
||||
from tests.fixtures import do_nothing
|
||||
|
||||
|
@ -28,10 +29,19 @@ class TestConnectionInheritance(RQTestCase):
|
|||
|
||||
def test_connection_pass_thru(self):
|
||||
"""Connection passed through from queues to jobs."""
|
||||
q1 = Queue()
|
||||
q1 = Queue(connection=self.testconn)
|
||||
with Connection(new_connection()):
|
||||
q2 = Queue()
|
||||
job1 = q1.enqueue(do_nothing)
|
||||
job2 = q2.enqueue(do_nothing)
|
||||
self.assertEqual(q1.connection, job1.connection)
|
||||
self.assertEqual(q2.connection, job2.connection)
|
||||
|
||||
def test_parse_connection(self):
|
||||
"""Test parsing `ssl` and UnixDomainSocketConnection"""
|
||||
_, _, kwargs = parse_connection(Redis(ssl=True))
|
||||
self.assertTrue(kwargs['ssl'])
|
||||
path = '/tmp/redis.sock'
|
||||
pool = ConnectionPool(connection_class=UnixDomainSocketConnection, path=path)
|
||||
_, _, kwargs = parse_connection(Redis(connection_pool=pool))
|
||||
self.assertTrue(kwargs['unix_socket_path'], path)
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import redis
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from multiprocessing import Process
|
||||
from unittest import mock
|
||||
|
@ -16,6 +18,17 @@ from tests import RQTestCase, find_empty_redis_database, ssl_test
|
|||
from .fixtures import kill_worker, say_hello
|
||||
|
||||
|
||||
class CustomRedisConnection(redis.Connection):
|
||||
"""Custom redis connection with a custom arg, used in test_custom_connection_pool"""
|
||||
|
||||
def __init__(self, *args, custom_arg=None, **kwargs):
|
||||
self.custom_arg = custom_arg
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_custom_arg(self):
|
||||
return self.custom_arg
|
||||
|
||||
|
||||
class TestScheduledJobRegistry(RQTestCase):
|
||||
def test_get_jobs_to_enqueue(self):
|
||||
"""Getting job ids to enqueue from ScheduledJobRegistry."""
|
||||
|
@ -425,3 +438,34 @@ class TestQueue(RQTestCase):
|
|||
job = queue.enqueue_in(timedelta(seconds=30), say_hello, retry=Retry(3, [2]))
|
||||
self.assertEqual(job.retries_left, 3)
|
||||
self.assertEqual(job.retry_intervals, [2])
|
||||
|
||||
def test_custom_connection_pool(self):
|
||||
"""Connection pool customizing. Ensure that we can properly set a
|
||||
custom connection pool class and pass extra arguments"""
|
||||
custom_conn = redis.Redis(
|
||||
connection_pool=redis.ConnectionPool(
|
||||
connection_class=CustomRedisConnection,
|
||||
db=4,
|
||||
custom_arg="foo",
|
||||
)
|
||||
)
|
||||
|
||||
queue = Queue(connection=custom_conn)
|
||||
scheduler = RQScheduler([queue], connection=custom_conn)
|
||||
|
||||
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
|
||||
|
||||
self.assertEqual(scheduler_connection.__class__, CustomRedisConnection)
|
||||
self.assertEqual(scheduler_connection.get_custom_arg(), "foo")
|
||||
|
||||
def test_no_custom_connection_pool(self):
|
||||
"""Connection pool customizing must not interfere if we're using a standard
|
||||
connection (non-pooled)"""
|
||||
standard_conn = redis.Redis(db=5)
|
||||
|
||||
queue = Queue(connection=standard_conn)
|
||||
scheduler = RQScheduler([queue], connection=standard_conn)
|
||||
|
||||
scheduler_connection = scheduler.connection.connection_pool.get_connection('info')
|
||||
|
||||
self.assertEqual(scheduler_connection.__class__, redis.Connection)
|
||||
|
|
Loading…
Reference in New Issue