diff --git a/rq/worker.py b/rq/worker.py
index 1e1bc953..2122e58c 100644
--- a/rq/worker.py
+++ b/rq/worker.py
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
     except ImportError:
         pass
     from redis import Redis
-    from redis.client import Pipeline, PubSub
+    from redis.client import Pipeline, PubSub, PubSubWorkerThread
 
 try:
     from signal import SIGKILL
@@ -950,12 +950,33 @@ class BaseWorker:
         self.clean_registries()
         Group.clean_registries(connection=self.connection)
 
+    def _pubsub_exception_handler(self, exc: Exception, pubsub: "PubSub", pubsub_thread: "PubSubWorkerThread") -> None:
+        """Keep the pubsub thread alive across Redis connection errors.
+
+        On `redis.exceptions.ConnectionError`, log and sleep so the thread
+        retries the same way the main worker loop indefinitely retries;
+        redis-py restores the channel subscriptions once the connection is
+        re-established. Any other exception is logged and re-raised.
+        """
+        if isinstance(exc, redis.exceptions.ConnectionError):
+            self.log.error(
+                "Could not connect to Redis instance: %s Retrying in %d seconds...",
+                exc,
+                2,
+            )
+            time.sleep(2.0)
+        else:
+            self.log.warning("Pubsub thread exiting on %s", exc)
+            raise exc
+
     def subscribe(self):
         """Subscribe to this worker's channel"""
         self.log.info('Subscribing to channel %s', self.pubsub_channel_name)
         self.pubsub = self.connection.pubsub()
         self.pubsub.subscribe(**{self.pubsub_channel_name: self.handle_payload})
-        self.pubsub_thread = self.pubsub.run_in_thread(sleep_time=0.2, daemon=True)
+        self.pubsub_thread = self.pubsub.run_in_thread(
+            sleep_time=0.2, daemon=True, exception_handler=self._pubsub_exception_handler
+        )
 
     def get_heartbeat_ttl(self, job: 'Job') -> int:
         """Get's the TTL for the next heartbeat.
diff --git a/tests/fixtures.py b/tests/fixtures.py
index 5e8d331f..c3d07584 100644
--- a/tests/fixtures.py
+++ b/tests/fixtures.py
@@ -47,7 +47,7 @@ def do_nothing():
     pass
 
 
-def raise_exc():
+def raise_exc(*args, **kwargs):
     raise Exception('raise_exc error')
 
 
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 0844cb45..b99f3aeb 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -1,5 +1,6 @@
 import time
 from multiprocessing import Process
+from unittest import mock
 
 from redis import Redis
 
@@ -9,7 +10,7 @@ from rq.exceptions import InvalidJobOperation, NoSuchJobError
 from rq.serializers import JSONSerializer
 from rq.worker import WorkerStatus
 from tests import RQTestCase
-from tests.fixtures import _send_kill_horse_command, _send_shutdown_command, long_running_job
+from tests.fixtures import _send_kill_horse_command, _send_shutdown_command, long_running_job, raise_exc_mock
 
 
 def start_work(queue_name, worker_name, connection_kwargs):
@@ -35,6 +36,31 @@ class TestCommands(RQTestCase):
         worker.work()
         p.join(1)
 
+    def test_pubsub_thread_survives_connection_error(self):
+        """Ensure that the pubsub thread is still alive after its Redis connection is killed."""
+        connection = self.connection
+        worker = Worker('foo', connection=connection)
+        worker.subscribe()
+
+        assert worker.pubsub_thread.is_alive()
+
+        for client in connection.client_list():
+            connection.client_kill(client["addr"])
+
+        time.sleep(0.5)  # Outlast the 0.2s pubsub poll so the thread has seen the dropped connection
+        assert worker.pubsub_thread.is_alive()
+
+    def test_pubsub_thread_exits_other_error(self):
+        """Ensure that the pubsub thread exits on errors other than redis.exceptions.ConnectionError."""
+        connection = self.connection
+        worker = Worker('foo', connection=connection)
+
+        with mock.patch("redis.client.PubSub.get_message", new_callable=raise_exc_mock):
+            worker.subscribe()
+
+        worker.pubsub_thread.join(timeout=5)  # Bounded wait: a regression fails the assert instead of hanging
+        assert not worker.pubsub_thread.is_alive()
+
     def test_kill_horse_command(self):
         """Ensure that shutdown command works properly."""
         connection = self.connection