From 3475986ee04590c90ce76e1eaf821fc80547948b Mon Sep 17 00:00:00 2001 From: vinay karanam Date: Mon, 28 Feb 2022 16:18:09 +0530 Subject: [PATCH] Added global_keyprefix support for pubsub clients (#1495) * Added global_keyprefix support for pubsub clients * Added test cases --- kombu/transport/redis.py | 62 ++++++++++++++++++++++++++++++++-- t/unit/transport/test_redis.py | 20 +++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 69153052..cc861995 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -216,8 +216,7 @@ class GlobalKeyPrefixMixin: if command in self.PREFIXED_SIMPLE_COMMANDS: args[0] = self.global_keyprefix + str(args[0]) - - if command in self.PREFIXED_COMPLEX_COMMANDS.keys(): + elif command in self.PREFIXED_COMPLEX_COMMANDS: args_start = self.PREFIXED_COMPLEX_COMMANDS[command]["args_start"] args_end = self.PREFIXED_COMPLEX_COMMANDS[command]["args_end"] @@ -267,6 +266,13 @@ class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.Redis): self.global_keyprefix = kwargs.pop('global_keyprefix', '') redis.Redis.__init__(self, *args, **kwargs) + def pubsub(self, **kwargs): + return PrefixedRedisPubSub( + self.connection_pool, + global_keyprefix=self.global_keyprefix, + **kwargs, + ) + class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): """Custom Redis pipeline that takes global_keyprefix into consideration. @@ -281,6 +287,58 @@ class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.client.Pipeline): redis.client.Pipeline.__init__(self, *args, **kwargs) +class PrefixedRedisPubSub(redis.client.PubSub): + """Redis pubsub client that takes global_keyprefix into consideration.""" + + PUBSUB_COMMANDS = ( + "SUBSCRIBE", + "UNSUBSCRIBE", + "PSUBSCRIBE", + "PUNSUBSCRIBE", + ) + + def __init__(self, *args, **kwargs): + self.global_keyprefix = kwargs.pop('global_keyprefix', '') + super().__init__(*args, **kwargs) + + def _prefix_args(self, args): + args = list(args) + command = args.pop(0) + + if command in self.PUBSUB_COMMANDS: + args = [ + self.global_keyprefix + str(arg) + for arg in args + ] + + return [command, *args] + + def parse_response(self, *args, **kwargs): + """Parse a response from the Redis server. + + Method wraps ``PubSub.parse_response()`` to remove prefixes of keys + returned by redis command. + """ + ret = super().parse_response(*args, **kwargs) + if ret is None: + return ret + + # response formats + # SUBSCRIBE and UNSUBSCRIBE + # -> [message type, channel, message] + # PSUBSCRIBE and PUNSUBSCRIBE + # -> [message type, pattern, channel, message] + message_type, *channels, message = ret + return [ + message_type, + *[channel[len(self.global_keyprefix):] for channel in channels], + message, + ] + + def execute_command(self, *args, **kwargs): + return super().execute_command(*self._prefix_args(args), **kwargs) + + class QoS(virtual.QoS): """Redis Ack Emulation.""" diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index ce448b37..029f7901 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -1114,6 +1114,26 @@ class test_Channel: '\x06\x16\x06\x16queue' ) + @patch("redis.client.PubSub.execute_command") + def test_global_keyprefix_pubsub(self, mock_execute_command): + from kombu.transport.redis import PrefixedStrictRedis + + with Connection(transport=Transport) as conn: + client = PrefixedStrictRedis(global_keyprefix='foo_') + + channel = conn.channel() + channel.global_keyprefix = 'foo_' + channel._create_client = Mock() + channel._create_client.return_value = client + channel.subclient.connection = Mock() + channel.active_fanout_queues.add('a') + + channel._subscribe() + mock_execute_command.assert_called_with( + 'PSUBSCRIBE', + 'foo_/{db}.a', + ) + class test_Redis: