diff --git a/docs/reference/kombu.transport.gcpubsub.rst b/docs/reference/kombu.transport.gcpubsub.rst new file mode 100644 index 00000000..bf5fc058 --- /dev/null +++ b/docs/reference/kombu.transport.gcpubsub.rst @@ -0,0 +1,24 @@ +============================================================== + Google Cloud Pub/Sub Transport - ``kombu.transport.gcpubsub`` +============================================================== + +.. currentmodule:: kombu.transport.gcpubsub + +.. automodule:: kombu.transport.gcpubsub + + .. contents:: + :local: + + Transport + --------- + + .. autoclass:: Transport + :members: + :undoc-members: + + Channel + ------- + + .. autoclass:: Channel + :members: + :undoc-members: diff --git a/kombu/transport/__init__.py b/kombu/transport/__init__.py index 64dd35c6..180a27b4 100644 --- a/kombu/transport/__init__.py +++ b/kombu/transport/__init__.py @@ -43,7 +43,8 @@ TRANSPORT_ALIASES = { 'etcd': 'kombu.transport.etcd:Transport', 'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport', 'azureservicebus': 'kombu.transport.azureservicebus:Transport', - 'pyro': 'kombu.transport.pyro:Transport' + 'pyro': 'kombu.transport.pyro:Transport', + 'gcpubsub': 'kombu.transport.gcpubsub:Transport', } _transport_cache = {} diff --git a/kombu/transport/gcpubsub.py b/kombu/transport/gcpubsub.py new file mode 100644 index 00000000..368c7a38 --- /dev/null +++ b/kombu/transport/gcpubsub.py @@ -0,0 +1,810 @@ +"""GCP Pub/Sub transport module for kombu. + +More information about GCP Pub/Sub: +https://cloud.google.com/pubsub + +Features +======== +* Type: Virtual +* Supports Direct: Yes +* Supports Topic: No +* Supports Fanout: Yes +* Supports Priority: No +* Supports TTL: No + +Connection String +================= + +Connection string has the following formats: + +.. code-block:: + + gcpubsub://projects/project-name + +Transport Options +================= +* ``queue_name_prefix``: (str) Prefix for queue names. +* ``ack_deadline_seconds``: (int) The maximum time after receiving a message + and acknowledging it before pub/sub redelivers the message. +* ``expiration_seconds``: (int) Subscriptions without any subscriber + activity or changes made to their properties are removed after this period. + Examples of subscriber activities include open connections, + active pulls, or successful pushes. +* ``wait_time_seconds``: (int) The maximum time to wait for new messages. + Defaults to 10. +* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying. +* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk. + Defaults to 32. +""" + +from __future__ import annotations + +import dataclasses +import datetime +import string +import threading +from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor, + wait) +from contextlib import suppress +from os import getpid +from queue import Empty +from threading import Lock +from time import monotonic, sleep +from uuid import NAMESPACE_OID, uuid3 + +from _socket import gethostname +from _socket import timeout as socket_timeout +from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, + PermissionDenied) +from google.api_core.retry import Retry +from google.cloud import monitoring_v3 +from google.cloud.monitoring_v3 import query +from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient +from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions +from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions +from google.cloud.pubsub_v1.subscriber import \ + exceptions as subscriber_exceptions +from google.pubsub_v1 import gapic_version as package_version + +from kombu.entity import TRANSIENT_DELIVERY_MODE +from kombu.log import get_logger +from kombu.utils.encoding import bytes_to_str, safe_str +from kombu.utils.json import dumps, loads +from kombu.utils.objects import cached_property + +from . import virtual + +logger = get_logger('kombu.transport.gcpubsub') + +# dots are replaced by dash, all other punctuation replaced by underscore. +PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'} +CHARS_REPLACE_TABLE = { + ord('.'): ord('-'), + **{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE}, +} + + +class UnackedIds: + """Threadsafe list of ack_ids.""" + + def __init__(self): + self._list = [] + self._lock = Lock() + + def append(self, val): + # append is atomic + self._list.append(val) + + def extend(self, vals: list): + # extend is atomic + self._list.extend(vals) + + def pop(self, index=-1): + with self._lock: + return self._list.pop(index) + + def remove(self, val): + with self._lock, suppress(ValueError): + self._list.remove(val) + + def __len__(self): + with self._lock: + return len(self._list) + + def __getitem__(self, item): + # getitem is atomic + return self._list[item] + + +class AtomicCounter: + """Threadsafe counter. + + Returns the value after inc/dec operations. + """ + + def __init__(self, initial=0): + self._value = initial + self._lock = Lock() + + def inc(self, n=1): + with self._lock: + self._value += n + return self._value + + def dec(self, n=1): + with self._lock: + self._value -= n + return self._value + + def get(self): + with self._lock: + return self._value + + +@dataclasses.dataclass +class QueueDescriptor: + """Pub/Sub queue descriptor.""" + + name: str + topic_path: str # projects/{project_id}/topics/{topic_id} + subscription_id: str + subscription_path: str # projects/{project_id}/subscriptions/{subscription_id} + unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds) + + +class Channel(virtual.Channel): + """GCP Pub/Sub channel.""" + + supports_fanout = True + do_restore = False # pub/sub does that for us + default_wait_time_seconds = 10 + default_ack_deadline_seconds = 240 + default_expiration_seconds = 86400 + default_retry_timeout_seconds = 300 + default_bulk_max_messages = 32 + + _min_ack_deadline = 10 + _fanout_exchanges = set() + _unacked_extender: threading.Thread = None + _stop_extender = threading.Event() + _n_channels = AtomicCounter() + _queue_cache: dict[str, QueueDescriptor] = {} + _tmp_subscriptions: set[str] = set() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pool = ThreadPoolExecutor() + logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname) + + self.project_id = Transport.parse_uri(self.conninfo.hostname) + if self._n_channels.inc() == 1: + Channel._unacked_extender = threading.Thread( + target=self._extend_unacked_deadline, + daemon=True, + ) + self._stop_extender.clear() + Channel._unacked_extender.start() + + def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str: + """Format AMQP queue name into a valid Pub/Sub queue name.""" + if not name.startswith(self.queue_name_prefix): + name = self.queue_name_prefix + name + + return str(safe_str(name)).translate(table) + + def _queue_bind(self, exchange, routing_key, pattern, queue): + exchange_type = self.typeof(exchange).type + queue = self.entity_name(queue) + logger.debug( + 'binding queue: %s to %s exchange: %s with routing_key: %s', + queue, + exchange_type, + exchange, + routing_key, + ) + + filter_args = {} + if exchange_type == 'direct': + # Direct exchange is implemented as a single subscription + # E.g. for exchange 'test_direct': + # -topic:'test_direct' + # -bound queue:'direct1': + # -subscription: direct1' on topic 'test_direct' + # -filter:routing_key' + filter_args = { + 'filter': f'attributes.routing_key="{routing_key}"' + } + subscription_path = self.subscriber.subscription_path( + self.project_id, queue + ) + message_retention_duration = self.expiration_seconds + elif exchange_type == 'fanout': + # Fanout exchange is implemented as a separate subscription. + # E.g. for exchange 'test_fanout': + # -topic:'test_fanout' + # -bound queue 'fanout1': + # -subscription:'fanout1-uuid' on topic 'test_fanout' + # -bound queue 'fanout2': + # -subscription:'fanout2-uuid' on topic 'test_fanout' + uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}' + uniq_sub_name = f'{queue}-{uid}' + subscription_path = self.subscriber.subscription_path( + self.project_id, uniq_sub_name + ) + self._tmp_subscriptions.add(subscription_path) + self._fanout_exchanges.add(exchange) + message_retention_duration = 600 + else: + raise NotImplementedError( + f'exchange type {exchange_type} not implemented' + ) + exchange_topic = self._create_topic( + self.project_id, exchange, message_retention_duration + ) + self._create_subscription( + topic_path=exchange_topic, + subscription_path=subscription_path, + filter_args=filter_args, + msg_retention=message_retention_duration, + ) + qdesc = QueueDescriptor( + name=queue, + topic_path=exchange_topic, + subscription_id=queue, + subscription_path=subscription_path, + ) + self._queue_cache[queue] = qdesc + + def _create_topic( + self, + project_id: str, + topic_id: str, + message_retention_duration: int = None, + ) -> str: + topic_path = self.publisher.topic_path(project_id, topic_id) + if self._is_topic_exists(topic_path): + # topic creation takes a while, so skip if possible + logger.debug('topic: %s exists', topic_path) + return topic_path + try: + logger.debug('creating topic: %s', topic_path) + request = {'name': topic_path} + if message_retention_duration: + request[ + 'message_retention_duration' + ] = f'{message_retention_duration}s' + self.publisher.create_topic(request=request) + except AlreadyExists: + pass + + return topic_path + + def _is_topic_exists(self, topic_path: str) -> bool: + topics = self.publisher.list_topics( + request={"project": f'projects/{self.project_id}'} + ) + for t in topics: + if t.name == topic_path: + return True + return False + + def _create_subscription( + self, + project_id: str = None, + topic_id: str = None, + topic_path: str = None, + subscription_path: str = None, + filter_args=None, + msg_retention: int = None, + ) -> str: + subscription_path = ( + subscription_path + or self.subscriber.subscription_path(self.project_id, topic_id) + ) + topic_path = topic_path or self.publisher.topic_path( + project_id, topic_id + ) + try: + logger.debug( + 'creating subscription: %s, topic: %s, filter: %s', + subscription_path, + topic_path, + filter_args, + ) + msg_retention = msg_retention or self.expiration_seconds + self.subscriber.create_subscription( + request={ + "name": subscription_path, + "topic": topic_path, + 'ack_deadline_seconds': self.ack_deadline_seconds, + 'expiration_policy': { + 'ttl': f'{self.expiration_seconds}s' + }, + 'message_retention_duration': f'{msg_retention}s', + **(filter_args or {}), + } + ) + except AlreadyExists: + pass + return subscription_path + + def _delete(self, queue, *args, **kwargs): + """Delete a queue by name.""" + queue = self.entity_name(queue) + logger.info('deleting queue: %s', queue) + qdesc = self._queue_cache.get(queue) + if not qdesc: + return + self.subscriber.delete_subscription( + request={"subscription": qdesc.subscription_path} + ) + self._queue_cache.pop(queue, None) + + def _put(self, queue, message, **kwargs): + """Put a message onto the queue.""" + queue = self.entity_name(queue) + qdesc = self._queue_cache[queue] + routing_key = self._get_routing_key(message) + logger.debug( + 'putting message to queue: %s, topic: %s, routing_key: %s', + queue, + qdesc.topic_path, + routing_key, + ) + encoded_message = dumps(message) + self.publisher.publish( + qdesc.topic_path, + encoded_message.encode("utf-8"), + routing_key=routing_key, + ) + + def _put_fanout(self, exchange, message, routing_key, **kwargs): + """Put a message onto fanout exchange.""" + self._lookup(exchange, routing_key) + topic_path = self.publisher.topic_path(self.project_id, exchange) + logger.debug( + 'putting msg to fanout exchange: %s, topic: %s', + exchange, + topic_path, + ) + encoded_message = dumps(message) + self.publisher.publish( + topic_path, + encoded_message.encode("utf-8"), + retry=Retry(deadline=self.retry_timeout_seconds), + ) + + def _get(self, queue: str, timeout: float = None): + """Retrieves a single message from a queue.""" + queue = self.entity_name(queue) + qdesc = self._queue_cache[queue] + try: + response = self.subscriber.pull( + request={ + 'subscription': qdesc.subscription_path, + 'max_messages': 1, + }, + retry=Retry(deadline=self.retry_timeout_seconds), + timeout=timeout or self.wait_time_seconds, + ) + except DeadlineExceeded: + raise Empty() + + if len(response.received_messages) == 0: + raise Empty() + + message = response.received_messages[0] + ack_id = message.ack_id + payload = loads(message.message.data) + delivery_info = payload['properties']['delivery_info'] + logger.debug( + 'queue:%s got message, ack_id: %s, payload: %s', + queue, + ack_id, + payload['properties'], + ) + if self._is_auto_ack(payload['properties']): + logger.debug('auto acking message ack_id: %s', ack_id) + self._do_ack([ack_id], qdesc.subscription_path) + else: + delivery_info['gcpubsub_message'] = { + 'queue': queue, + 'ack_id': ack_id, + 'message_id': message.message.message_id, + 'subscription_path': qdesc.subscription_path, + } + qdesc.unacked_ids.append(ack_id) + + return payload + + def _is_auto_ack(self, payload_properties: dict): + exchange = payload_properties['delivery_info']['exchange'] + delivery_mode = payload_properties['delivery_mode'] + return ( + delivery_mode == TRANSIENT_DELIVERY_MODE + or exchange in self._fanout_exchanges + ) + + def _get_bulk(self, queue: str, timeout: float): + """Retrieves bulk of messages from a queue.""" + prefixed_queue = self.entity_name(queue) + qdesc = self._queue_cache[prefixed_queue] + max_messages = self._get_max_messages_estimate() + if not max_messages: + raise Empty() + try: + response = self.subscriber.pull( + request={ + 'subscription': qdesc.subscription_path, + 'max_messages': max_messages, + }, + retry=Retry(deadline=self.retry_timeout_seconds), + timeout=timeout or self.wait_time_seconds, + ) + except DeadlineExceeded: + raise Empty() + + received_messages = response.received_messages + if len(received_messages) == 0: + raise Empty() + + auto_ack_ids = [] + ret_payloads = [] + logger.debug( + 'batching %d messages from queue: %s', + len(received_messages), + prefixed_queue, + ) + for message in received_messages: + ack_id = message.ack_id + payload = loads(bytes_to_str(message.message.data)) + delivery_info = payload['properties']['delivery_info'] + delivery_info['gcpubsub_message'] = { + 'queue': prefixed_queue, + 'ack_id': ack_id, + 'message_id': message.message.message_id, + 'subscription_path': qdesc.subscription_path, + } + if self._is_auto_ack(payload['properties']): + auto_ack_ids.append(ack_id) + else: + qdesc.unacked_ids.append(ack_id) + ret_payloads.append(payload) + if auto_ack_ids: + logger.debug('auto acking ack_ids: %s', auto_ack_ids) + self._do_ack(auto_ack_ids, qdesc.subscription_path) + + return queue, ret_payloads + + def _get_max_messages_estimate(self) -> int: + max_allowed = self.qos.can_consume_max_estimate() + max_if_unlimited = self.bulk_max_messages + return max_if_unlimited if max_allowed is None else max_allowed + + def _lookup(self, exchange, routing_key, default=None): + exchange_info = self.state.exchanges.get(exchange, {}) + if not exchange_info: + return super()._lookup(exchange, routing_key, default) + ret = self.typeof(exchange).lookup( + self.get_table(exchange), + exchange, + routing_key, + default, + ) + if ret: + return ret + logger.debug( + 'no queues bound to exchange: %s, binding on the fly', + exchange, + ) + self.queue_bind(exchange, exchange, routing_key) + return [exchange] + + def _size(self, queue: str) -> int: + """Return the number of messages in a queue. + + This is a *rough* estimation, as Pub/Sub doesn't provide + an exact API. + """ + queue = self.entity_name(queue) + if queue not in self._queue_cache: + return 0 + qdesc = self._queue_cache[queue] + result = query.Query( + self.monitor, + self.project_id, + 'pubsub.googleapis.com/subscription/num_undelivered_messages', + end_time=datetime.datetime.now(), + minutes=1, + ).select_resources(subscription_id=qdesc.subscription_id) + + # monitoring API requires the caller to have the monitoring.viewer + # role. Since we can live without the exact number of messages + # in the queue, we can ignore the exception and allow users to + # use the transport without this role. + with suppress(PermissionDenied): + return sum( + content.points[0].value.int64_value for content in result + ) + return -1 + + def basic_ack(self, delivery_tag, multiple=False): + """Acknowledge one message.""" + if multiple: + raise NotImplementedError('multiple acks not implemented') + + delivery_info = self.qos.get(delivery_tag).delivery_info + pubsub_message = delivery_info['gcpubsub_message'] + ack_id = pubsub_message['ack_id'] + queue = pubsub_message['queue'] + logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id) + subscription_path = pubsub_message['subscription_path'] + self._do_ack([ack_id], subscription_path) + qdesc = self._queue_cache[queue] + qdesc.unacked_ids.remove(ack_id) + super().basic_ack(delivery_tag) + + def _do_ack(self, ack_ids: list[str], subscription_path: str): + self.subscriber.acknowledge( + request={"subscription": subscription_path, "ack_ids": ack_ids}, + retry=Retry(deadline=self.retry_timeout_seconds), + ) + + def _purge(self, queue: str): + """Delete all current messages in a queue.""" + queue = self.entity_name(queue) + qdesc = self._queue_cache.get(queue) + if not qdesc: + return + + n = self._size(queue) + self.subscriber.seek( + request={ + "subscription": qdesc.subscription_path, + "time": datetime.datetime.now(), + } + ) + return n + + def _extend_unacked_deadline(self): + thread_id = threading.get_native_id() + logger.info( + 'unacked deadline extension thread: [%s] started', + thread_id, + ) + min_deadline_sleep = self._min_ack_deadline / 2 + sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4) + while not self._stop_extender.wait(sleep_time): + for qdesc in self._queue_cache.values(): + if len(qdesc.unacked_ids) == 0: + logger.debug( + 'thread [%s]: no unacked messages for %s', + thread_id, + qdesc.subscription_path, + ) + continue + logger.debug( + 'thread [%s]: extend ack deadline for %s: %d msgs [%s]', + thread_id, + qdesc.subscription_path, + len(qdesc.unacked_ids), + list(qdesc.unacked_ids), + ) + self.subscriber.modify_ack_deadline( + request={ + "subscription": qdesc.subscription_path, + "ack_ids": list(qdesc.unacked_ids), + "ack_deadline_seconds": self.ack_deadline_seconds, + } + ) + logger.info( + 'unacked deadline extension thread [%s] stopped', thread_id + ) + + def after_reply_message_received(self, queue: str): + queue = self.entity_name(queue) + sub = self.subscriber.subscription_path(self.project_id, queue) + logger.debug( + 'after_reply_message_received: queue: %s, sub: %s', queue, sub + ) + self._tmp_subscriptions.add(sub) + + @cached_property + def subscriber(self): + return SubscriberClient() + + @cached_property + def publisher(self): + return PublisherClient() + + @cached_property + def monitor(self): + return monitoring_v3.MetricServiceClient() + + @property + def conninfo(self): + return self.connection.client + + @property + def transport_options(self): + return self.connection.client.transport_options + + @cached_property + def wait_time_seconds(self): + return self.transport_options.get( + 'wait_time_seconds', self.default_wait_time_seconds + ) + + @cached_property + def retry_timeout_seconds(self): + return self.transport_options.get( + 'retry_timeout_seconds', self.default_retry_timeout_seconds + ) + + @cached_property + def ack_deadline_seconds(self): + return self.transport_options.get( + 'ack_deadline_seconds', self.default_ack_deadline_seconds + ) + + @cached_property + def queue_name_prefix(self): + return self.transport_options.get('queue_name_prefix', 'kombu-') + + @cached_property + def expiration_seconds(self): + return self.transport_options.get( + 'expiration_seconds', self.default_expiration_seconds + ) + + @cached_property + def bulk_max_messages(self): + return self.transport_options.get( + 'bulk_max_messages', self.default_bulk_max_messages + ) + + def close(self): + """Close the channel.""" + logger.debug('closing channel') + while self._tmp_subscriptions: + sub = self._tmp_subscriptions.pop() + with suppress(Exception): + logger.debug('deleting subscription: %s', sub) + self.subscriber.delete_subscription( + request={"subscription": sub} + ) + if not self._n_channels.dec(): + self._stop_extender.set() + Channel._unacked_extender.join() + super().close() + + @staticmethod + def _get_routing_key(message): + routing_key = ( + message['properties'] + .get('delivery_info', {}) + .get('routing_key', '') + ) + return routing_key + + +class Transport(virtual.Transport): + """GCP Pub/Sub transport.""" + + Channel = Channel + + can_parse_url = True + polling_interval = 0.1 + connection_errors = virtual.Transport.connection_errors + ( + pubsub_exceptions.TimeoutError, + ) + channel_errors = ( + virtual.Transport.channel_errors + + ( + publisher_exceptions.FlowControlLimitError, + publisher_exceptions.MessageTooLargeError, + publisher_exceptions.PublishError, + publisher_exceptions.TimeoutError, + publisher_exceptions.PublishToPausedOrderingKeyException, + ) + + (subscriber_exceptions.AcknowledgeError,) + ) + + driver_type = 'gcpubsub' + driver_name = 'pubsub_v1' + + implements = virtual.Transport.implements.extend( + exchange_type=frozenset(['direct', 'fanout']), + ) + + def __init__(self, client, **kwargs): + super().__init__(client, **kwargs) + self._pool = ThreadPoolExecutor() + self._get_bulk_future_to_queue: dict[Future, str] = dict() + + def driver_version(self): + return package_version.__version__ + + @staticmethod + def parse_uri(uri: str) -> str: + # URL like: + # gcpubsub://projects/project-name + + project = uri.split('gcpubsub://projects/')[1] + return project.strip('/') + + @classmethod + def as_uri(self, uri: str, include_password=False, mask='**') -> str: + return uri or 'gcpubsub://' + + def drain_events(self, connection, timeout=None): + time_start = monotonic() + polling_interval = self.polling_interval + if timeout and polling_interval and polling_interval > timeout: + polling_interval = timeout + while 1: + try: + self._drain_from_active_queues(timeout=timeout) + except Empty: + if timeout and monotonic() - time_start >= timeout: + raise socket_timeout() + if polling_interval: + sleep(polling_interval) + else: + break + + def _drain_from_active_queues(self, timeout): + # cleanup empty requests from prev run + self._rm_empty_bulk_requests() + + # submit new requests for all active queues + # longer timeout means less frequent polling + # and more messages in a single bulk + self._submit_get_bulk_requests(timeout=10) + + done, _ = wait( + self._get_bulk_future_to_queue, + timeout=timeout, + return_when=FIRST_COMPLETED, + ) + empty = {f for f in done if f.exception()} + done -= empty + for f in empty: + self._get_bulk_future_to_queue.pop(f, None) + + if not done: + raise Empty() + + logger.debug('got %d done get_bulk tasks', len(done)) + for f in done: + queue, payloads = f.result() + for payload in payloads: + logger.debug('consuming message from queue: %s', queue) + if queue not in self._callbacks: + logger.warning( + 'Message for queue %s without consumers', queue + ) + continue + self._deliver(payload, queue) + self._get_bulk_future_to_queue.pop(f, None) + + def _rm_empty_bulk_requests(self): + empty = { + f + for f in self._get_bulk_future_to_queue + if f.done() and f.exception() + } + for f in empty: + self._get_bulk_future_to_queue.pop(f, None) + + def _submit_get_bulk_requests(self, timeout): + queues_with_submitted_get_bulk = set( + self._get_bulk_future_to_queue.values() + ) + + for channel in self.channels: + for queue in channel._active_queues: + if queue in queues_with_submitted_get_bulk: + continue + future = self._pool.submit(channel._get_bulk, queue, timeout) + self._get_bulk_future_to_queue[future] = queue diff --git a/requirements/extras/gcpubsub.txt b/requirements/extras/gcpubsub.txt new file mode 100644 index 00000000..6dfee275 --- /dev/null +++ b/requirements/extras/gcpubsub.txt @@ -0,0 +1,3 @@ +google-cloud-pubsub>=2.18.4 +google-cloud-monitoring>=2.16.0 +grpcio==1.66.2 diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index f9ffaf95..8882410e 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -16,3 +16,4 @@ urllib3>=1.26.16; sys_platform != 'win32' -r extras/zstd.txt -r extras/sqlalchemy.txt -r extras/etcd.txt +-r extras/gcpubsub.txt diff --git a/setup.py b/setup.py index da1ac730..f4cb2579 100644 --- a/setup.py +++ b/setup.py @@ -103,6 +103,7 @@ setup( 'redis': extras('redis.txt'), 'mongodb': extras('mongodb.txt'), 'sqs': extras('sqs.txt'), + 'gcpubsub': extras('gcpubsub.txt'), 'zookeeper': extras('zookeeper.txt'), 'sqlalchemy': extras('sqlalchemy.txt'), 'librabbitmq': extras('librabbitmq.txt'), diff --git a/t/unit/transport/test_gcpubsub.py b/t/unit/transport/test_gcpubsub.py new file mode 100644 index 00000000..5e60336b --- /dev/null +++ b/t/unit/transport/test_gcpubsub.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +from concurrent.futures import Future +from datetime import datetime +from queue import Empty +from unittest.mock import MagicMock, call, patch + +import pytest +from _socket import timeout as socket_timeout +from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded, + PermissionDenied) + +from kombu.transport.gcpubsub import (AtomicCounter, Channel, QueueDescriptor, + Transport, UnackedIds) + + +class test_UnackedIds: + def setup_method(self): + self.unacked_ids = UnackedIds() + + def test_append(self): + self.unacked_ids.append('test_id') + assert self.unacked_ids[0] == 'test_id' + + def test_extend(self): + self.unacked_ids.extend(['test_id1', 'test_id2']) + assert self.unacked_ids[0] == 'test_id1' + assert self.unacked_ids[1] == 'test_id2' + + def test_pop(self): + self.unacked_ids.append('test_id') + popped_id = self.unacked_ids.pop() + assert popped_id == 'test_id' + assert len(self.unacked_ids) == 0 + + def test_remove(self): + self.unacked_ids.append('test_id') + self.unacked_ids.remove('test_id') + assert len(self.unacked_ids) == 0 + + def test_len(self): + self.unacked_ids.append('test_id') + assert len(self.unacked_ids) == 1 + + def test_getitem(self): + self.unacked_ids.append('test_id') + assert self.unacked_ids[0] == 'test_id' + + +class test_AtomicCounter: + def setup_method(self): + self.counter = AtomicCounter() + + def test_inc(self): + assert self.counter.inc() == 1 + assert self.counter.inc(5) == 6 + + def test_dec(self): + self.counter.inc(5) + assert self.counter.dec() == 4 + assert self.counter.dec(2) == 2 + + def test_get(self): + self.counter.inc(7) + assert self.counter.get() == 7 + + +@pytest.fixture +def channel(): + with patch.object(Channel, '__init__', lambda self: None): + channel = Channel() + channel.connection = MagicMock() + channel.queue_name_prefix = "kombu-" + channel.project_id = "test_project" + channel._queue_cache = {} + channel._n_channels = MagicMock() + channel._stop_extender = MagicMock() + channel.subscriber = MagicMock() + channel.publisher = MagicMock() + channel.closed = False + with patch.object( + Channel, 'conninfo', new_callable=MagicMock + ), patch.object( + Channel, 'transport_options', new_callable=MagicMock + ), patch.object( + Channel, 'qos', new_callable=MagicMock + ): + yield channel + + +class test_Channel: + @patch('kombu.transport.gcpubsub.ThreadPoolExecutor') + @patch('kombu.transport.gcpubsub.threading.Event') + @patch('kombu.transport.gcpubsub.threading.Thread') + @patch( + 'kombu.transport.gcpubsub.Channel._get_free_channel_id', + return_value=1, + ) + @patch( + 'kombu.transport.gcpubsub.Channel._n_channels.inc', + return_value=1, + ) + def test_channel_init( + self, + n_channels_in_mock, + channel_id_mock, + mock_thread, + mock_event, + mock_executor, + ): + mock_connection = MagicMock() + ch = Channel(mock_connection) + ch._n_channels.inc.assert_called_once() + mock_thread.assert_called_once_with( + target=ch._extend_unacked_deadline, + daemon=True, + ) + mock_thread.return_value.start.assert_called_once() + + def test_entity_name(self, channel): + name = "test_queue" + result = channel.entity_name(name) + assert result == "kombu-test_queue" + + @patch('kombu.transport.gcpubsub.uuid3', return_value='uuid') + @patch('kombu.transport.gcpubsub.gethostname', return_value='hostname') + @patch('kombu.transport.gcpubsub.getpid', return_value=1234) + def test_queue_bind_direct( + self, mock_pid, mock_hostname, mock_uuid, channel + ): + exchange = 'direct' + routing_key = 'test_routing_key' + pattern = 'test_pattern' + queue = 'test_queue' + subscription_path = 'projects/project-id/subscriptions/test_queue' + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel._create_topic = MagicMock(return_value='topic_path') + channel._create_subscription = MagicMock() + + # Mock the state and exchange type + mock_connection = MagicMock(name='mock_connection') + channel.connection = mock_connection + channel.state.exchanges = {exchange: {'type': 'direct'}} + mock_exchange = MagicMock(name='mock_exchange', type='direct') + channel.exchange_types = {'direct': mock_exchange} + + channel._queue_bind(exchange, routing_key, pattern, queue) + + channel._create_topic.assert_called_once_with( + channel.project_id, exchange, channel.expiration_seconds + ) + channel._create_subscription.assert_called_once_with( + topic_path='topic_path', + subscription_path=subscription_path, + filter_args={'filter': f'attributes.routing_key="{routing_key}"'}, + msg_retention=channel.expiration_seconds, + ) + assert channel.entity_name(queue) in channel._queue_cache + + @patch('kombu.transport.gcpubsub.uuid3', return_value='uuid') + @patch('kombu.transport.gcpubsub.gethostname', return_value='hostname') + @patch('kombu.transport.gcpubsub.getpid', return_value=1234) + def test_queue_bind_fanout( + self, mock_pid, mock_hostname, mock_uuid, channel + ): + exchange = 'test_exchange' + routing_key = 'test_routing_key' + pattern = 'test_pattern' + queue = 'test_queue' + uniq_sub_name = 'test_queue-uuid' + subscription_path = ( + f'projects/project-id/subscriptions/{uniq_sub_name}' + ) + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel._create_topic = MagicMock(return_value='topic_path') + channel._create_subscription = MagicMock() + + # Mock the state and exchange type + mock_connection = MagicMock(name='mock_connection') + channel.connection = mock_connection + channel.state.exchanges = {exchange: {'type': 'fanout'}} + mock_exchange = MagicMock(name='mock_exchange', type='fanout') + channel.exchange_types = {'fanout': mock_exchange} + + channel._queue_bind(exchange, routing_key, pattern, queue) + + channel._create_topic.assert_called_once_with( + channel.project_id, exchange, 600 + ) + channel._create_subscription.assert_called_once_with( + topic_path='topic_path', + subscription_path=subscription_path, + filter_args={}, + msg_retention=600, + ) + assert channel.entity_name(queue) in channel._queue_cache + assert subscription_path in channel._tmp_subscriptions + assert exchange in channel._fanout_exchanges + + def test_queue_bind_not_implemented(self, channel): + exchange = 'test_exchange' + routing_key = 'test_routing_key' + pattern = 'test_pattern' + queue = 'test_queue' + channel.typeof = MagicMock(return_value=MagicMock(type='unsupported')) + + with pytest.raises(NotImplementedError): + channel._queue_bind(exchange, routing_key, pattern, queue) + + def test_create_topic(self, channel): + channel.project_id = "project_id" + topic_id = "topic_id" + channel._is_topic_exists = MagicMock(return_value=False) + channel.publisher.topic_path = MagicMock(return_value="topic_path") + channel.publisher.create_topic = MagicMock() + result = channel._create_topic(channel.project_id, topic_id) + assert result == "topic_path" + channel.publisher.create_topic.assert_called_once() + + channel._create_topic( + channel.project_id, topic_id, message_retention_duration=10 + ) + assert ( + dict( + request={ + 'name': 'topic_path', + 'message_retention_duration': '10s', + } + ) + in channel.publisher.create_topic.call_args + ) + channel.publisher.create_topic.side_effect = AlreadyExists( + "test_error" + ) + channel._create_topic( + channel.project_id, topic_id, message_retention_duration=10 + ) + + def test_is_topic_exists(self, channel): + topic_path = "projects/project-id/topics/test_topic" + mock_topic = MagicMock() + mock_topic.name = topic_path + channel.publisher.list_topics.return_value = [mock_topic] + + result = channel._is_topic_exists(topic_path) + + assert result is True + channel.publisher.list_topics.assert_called_once_with( + request={"project": f'projects/{channel.project_id}'} + ) + + def test_is_topic_not_exists(self, channel): + topic_path = "projects/project-id/topics/test_topic" + channel.publisher.list_topics.return_value = [] + + result = channel._is_topic_exists(topic_path) + + assert result is False + channel.publisher.list_topics.assert_called_once_with( + request={"project": f'projects/{channel.project_id}'} + ) + + def test_create_subscription(self, channel): + channel.project_id = "project_id" + topic_id = "topic_id" + subscription_path = "subscription_path" + topic_path = "topic_path" + channel.subscriber.subscription_path = MagicMock( + return_value=subscription_path + ) + channel.publisher.topic_path = MagicMock(return_value=topic_path) + channel.subscriber.create_subscription = MagicMock() + result = channel._create_subscription( + project_id=channel.project_id, + topic_id=topic_id, + subscription_path=subscription_path, + topic_path=topic_path, + ) + assert result == subscription_path + channel.subscriber.create_subscription.assert_called_once() + + def test_delete(self, channel): + queue = "test_queue" + subscription_path = "projects/project-id/subscriptions/test_queue" + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id=queue, + subscription_path=subscription_path, + ) + channel.subscriber = MagicMock() + channel._queue_cache[channel.entity_name(queue)] = qdesc + + channel._delete(queue) + + channel.subscriber.delete_subscription.assert_called_once_with( + request={"subscription": subscription_path} + ) + assert queue not in channel._queue_cache + + def test_put(self, channel): + queue = "test_queue" + message = { + "properties": {"delivery_info": {"routing_key": "test_key"}} + } + channel.entity_name = MagicMock(return_value=queue) + channel._queue_cache[channel.entity_name(queue)] = QueueDescriptor( + name=queue, + topic_path="topic_path", + subscription_id=queue, + subscription_path="subscription_path", + ) + channel._get_routing_key = MagicMock(return_value="test_key") + channel.publisher.publish = MagicMock() + channel._put(queue, message) + channel.publisher.publish.assert_called_once() + + def test_put_fanout(self, channel): + exchange = "test_exchange" + message = { + "properties": {"delivery_info": {"routing_key": "test_key"}} + } + routing_key = "test_key" + + channel._lookup = MagicMock() + channel.publisher.topic_path = MagicMock(return_value="topic_path") + channel.publisher.publish = MagicMock() + + channel._put_fanout(exchange, message, routing_key) + + channel._lookup.assert_called_once_with(exchange, routing_key) + channel.publisher.topic_path.assert_called_once_with( + channel.project_id, exchange + ) + assert 'topic_path', ( + b'{"properties": {"delivery_info": {"routing_key": "test_key"}}}' + in channel.publisher.publish.call_args + ) + + def test_get(self, channel): + queue = "test_queue" + channel.entity_name = MagicMock(return_value=queue) + channel._queue_cache[queue] = QueueDescriptor( + name=queue, + topic_path="topic_path", + subscription_id=queue, + subscription_path="subscription_path", + ) + channel.subscriber.pull = MagicMock( + return_value=MagicMock( + received_messages=[ + MagicMock( + ack_id="ack_id", + message=MagicMock( + data=b'{"properties": ' + b'{"delivery_info": ' + b'{"exchange": "exchange"},"delivery_mode": 1}}' + ), + ) + ] + ) + ) + channel.subscriber.acknowledge = MagicMock() + payload = channel._get(queue) + assert ( + payload["properties"]["delivery_info"]["exchange"] == "exchange" + ) + channel.subscriber.pull.side_effect = DeadlineExceeded("test_error") + with pytest.raises(Empty): + channel._get(queue) + + def test_get_bulk(self, channel): + queue = "test_queue" + subscription_path = "projects/project-id/subscriptions/test_queue" + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id=queue, + subscription_path=subscription_path, + ) + channel._queue_cache[channel.entity_name(queue)] = qdesc + + data = b'{"properties": {"delivery_info": {"exchange": "exchange"}}}' + received_message = MagicMock( + ack_id="ack_id", + message=MagicMock(data=data), + ) + channel.subscriber.pull = MagicMock( + return_value=MagicMock(received_messages=[received_message]) + ) + channel.bulk_max_messages = 10 + channel._is_auto_ack = MagicMock(return_value=True) + channel._do_ack = MagicMock() + channel.qos.can_consume_max_estimate = MagicMock(return_value=None) + queue, payloads = channel._get_bulk(queue, timeout=10) + + assert len(payloads) == 1 + assert ( + payloads[0]["properties"]["delivery_info"]["exchange"] + == "exchange" + ) + channel._do_ack.assert_called_once_with(["ack_id"], subscription_path) + + channel.subscriber.pull.side_effect = DeadlineExceeded("test_error") + with pytest.raises(Empty): + channel._get_bulk(queue, timeout=10) + + def test_lookup(self, channel): + exchange = "test_exchange" + routing_key = "test_key" + default = None + + channel.connection = MagicMock() + channel.state.exchanges = {exchange: {"type": "direct"}} + channel.typeof = MagicMock( + return_value=MagicMock(lookup=MagicMock(return_value=["queue1"])) + ) + channel.get_table = MagicMock(return_value="table") + + result = channel._lookup(exchange, routing_key, default) + + channel.typeof.return_value.lookup.assert_called_once_with( + "table", exchange, routing_key, default + ) + assert result == ["queue1"] + + # Test the case where no queues are bound to the exchange + channel.typeof.return_value.lookup.return_value = None + channel.queue_bind = MagicMock() + + result = channel._lookup(exchange, routing_key, default) + + channel.queue_bind.assert_called_once_with( + exchange, exchange, routing_key + ) + assert result == [exchange] + + @patch('kombu.transport.gcpubsub.monitoring_v3') + @patch('kombu.transport.gcpubsub.query.Query') + def test_size(self, mock_query, mock_monitor, channel): + queue = "test_queue" + subscription_id = "test_subscription" + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id=subscription_id, + subscription_path="projects/project-id/subscriptions/test_subscription", # E501 + ) + channel._queue_cache[channel.entity_name(queue)] = qdesc + + mock_query_result = MagicMock() + mock_query_result.select_resources.return_value = [ + MagicMock(points=[MagicMock(value=MagicMock(int64_value=5))]) + ] + mock_query.return_value = mock_query_result + size = channel._size(queue) + assert size == 5 + + # Test the case where the queue is not in the cache + size = channel._size("non_existent_queue") + assert size == 0 + + # Test the case where the query raises PermissionDenied + mock_item = MagicMock() + mock_item.points.__getitem__.side_effect = PermissionDenied( + 'test_error' + ) + mock_query_result.select_resources.return_value = [mock_item] + size = channel._size(queue) + assert size == -1 + + def test_basic_ack(self, channel): + delivery_tag = "test_delivery_tag" + ack_id = "test_ack_id" + queue = "test_queue" + subscription_path = ( + "projects/project-id/subscriptions/test_subscription" + ) + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id="test_subscription", + subscription_path=subscription_path, + ) + channel._queue_cache[queue] = qdesc + + delivery_info = { + 'gcpubsub_message': { + 'queue': queue, + 'ack_id': ack_id, + 'subscription_path': subscription_path, + } + } + channel.qos.get = MagicMock( + return_value=MagicMock(delivery_info=delivery_info) + ) + channel._do_ack = MagicMock() + + channel.basic_ack(delivery_tag) + + channel._do_ack.assert_called_once_with([ack_id], subscription_path) + assert ack_id not in qdesc.unacked_ids + + def test_do_ack(self, channel): + ack_ids = ["ack_id1", "ack_id2"] + subscription_path = ( + "projects/project-id/subscriptions/test_subscription" + ) + channel.subscriber = MagicMock() + + channel._do_ack(ack_ids, subscription_path) + assert subscription_path, ( + ack_ids in channel.subscriber.acknowledge.call_args + ) + + def test_purge(self, channel): + queue = "test_queue" + subscription_path = f"projects/project-id/subscriptions/{queue}" + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id="test_subscription", + subscription_path=subscription_path, + ) + channel._queue_cache[channel.entity_name(queue)] = qdesc + channel.subscriber = MagicMock() + + with patch.object(channel, '_size', return_value=10), patch( + 'kombu.transport.gcpubsub.datetime.datetime' + ) as dt_mock: + dt_mock.now.return_value = datetime(2021, 1, 1) + result = channel._purge(queue) + assert result == 10 + channel.subscriber.seek.assert_called_once_with( + request={ + "subscription": subscription_path, + "time": datetime(2021, 1, 1), + } + ) + + # Test the case where the queue is not in the cache + result = channel._purge("non_existent_queue") + assert result is None + + def test_extend_unacked_deadline(self, channel): + queue = "test_queue" + subscription_path = ( + "projects/project-id/subscriptions/test_subscription" + ) + ack_ids = ["ack_id1", "ack_id2"] + qdesc = QueueDescriptor( + name=queue, + topic_path="projects/project-id/topics/test_topic", + subscription_id="test_subscription", + subscription_path=subscription_path, + ) + channel.transport_options = {"ack_deadline_seconds": 240} + channel._queue_cache[channel.entity_name(queue)] = qdesc + qdesc.unacked_ids.extend(ack_ids) + + channel._stop_extender.wait = MagicMock(side_effect=[False, True]) + channel.subscriber.modify_ack_deadline = MagicMock() + + channel._extend_unacked_deadline() + + channel.subscriber.modify_ack_deadline.assert_called_once_with( + request={ + "subscription": subscription_path, + "ack_ids": ack_ids, + "ack_deadline_seconds": 240, + } + ) + for _ in ack_ids: + qdesc.unacked_ids.pop() + channel._stop_extender.wait = MagicMock(side_effect=[False, True]) + modify_ack_deadline_calls = ( + channel.subscriber.modify_ack_deadline.call_count + ) + channel._extend_unacked_deadline() + assert ( + channel.subscriber.modify_ack_deadline.call_count + == modify_ack_deadline_calls + ) + + def test_after_reply_message_received(self, channel): + queue = 'test-queue' + subscription_path = f'projects/test-project/subscriptions/{queue}' + + channel.subscriber.subscription_path.return_value = subscription_path + channel.after_reply_message_received(queue) + + # Check that the subscription path is added to _tmp_subscriptions + assert subscription_path in channel._tmp_subscriptions + + def test_subscriber(self, channel): + assert channel.subscriber + + def test_publisher(self, channel): + assert channel.publisher + + def test_transport_options(self, channel): + assert channel.transport_options + + def test_bulk_max_messages_default(self, channel): + assert channel.bulk_max_messages == channel.transport_options.get( + 'bulk_max_messages' + ) + + def test_close(self, channel): + channel._tmp_subscriptions = {'sub1', 'sub2'} + channel._n_channels.dec.return_value = 0 + + with patch.object( + Channel._unacked_extender, 'join' + ) as mock_join, patch( + 'kombu.transport.virtual.Channel.close' + ) as mock_super_close: + channel.close() + + channel.subscriber.delete_subscription.assert_has_calls( + [ + call(request={"subscription": 'sub1'}), + call(request={"subscription": 'sub2'}), + ], + any_order=True, + ) + channel._stop_extender.set.assert_called_once() + mock_join.assert_called_once() + mock_super_close.assert_called_once() + + +@pytest.fixture +def transport(): + return Transport(client=MagicMock()) + + +class test_Transport: + def test_driver_version(self, transport): + assert transport.driver_version() + + def test_as_uri(self, transport): + result = transport.as_uri('gcpubsub://') + assert result == 'gcpubsub://' + + def test_drain_events_timeout(self, transport): + transport.polling_interval = 4 + with patch.object( + transport, '_drain_from_active_queues', side_effect=Empty + ), patch( + 'kombu.transport.gcpubsub.monotonic', + side_effect=[0, 1, 2, 3, 4, 5], + ), patch( + 'kombu.transport.gcpubsub.sleep' + ) as mock_sleep: + with pytest.raises(socket_timeout): + transport.drain_events(None, timeout=3) + mock_sleep.assert_called() + + def test_drain_events_no_timeout(self, transport): + with patch.object( + transport, '_drain_from_active_queues', side_effect=[Empty, None] + ), patch( + 'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1] + ), patch( + 'kombu.transport.gcpubsub.sleep' + ) as mock_sleep: + transport.drain_events(None, timeout=None) + mock_sleep.assert_called() + + def test_drain_events_polling_interval(self, transport): + transport.polling_interval = 2 + with patch.object( + transport, '_drain_from_active_queues', side_effect=[Empty, None] + ), patch( + 'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1, 2] + ), patch( + 'kombu.transport.gcpubsub.sleep' + ) as mock_sleep: + transport.drain_events(None, timeout=5) + mock_sleep.assert_called_with(2) + + def test_drain_from_active_queues_empty(self, transport): + with patch.object( + transport, '_rm_empty_bulk_requests' + ) as mock_rm_empty, patch.object( + transport, '_submit_get_bulk_requests' + ) as mock_submit, patch( + 'kombu.transport.gcpubsub.wait', return_value=(set(), set()) + ) as mock_wait: + with pytest.raises(Empty): + transport._drain_from_active_queues(timeout=10) + mock_rm_empty.assert_called_once() + mock_submit.assert_called_once_with(timeout=10) + mock_wait.assert_called_once() + + def test_drain_from_active_queues_done(self, transport): + future = Future() + future.set_result(('queue', [{'properties': {'delivery_info': {}}}])) + + with patch.object( + transport, '_rm_empty_bulk_requests' + ) as mock_rm_empty, patch.object( + transport, '_submit_get_bulk_requests' + ) as mock_submit, patch( + 'kombu.transport.gcpubsub.wait', return_value=({future}, set()) + ) as mock_wait, patch.object( + transport, '_deliver' + ) as mock_deliver: + transport._callbacks = {'queue'} + transport._drain_from_active_queues(timeout=10) + mock_rm_empty.assert_called_once() + mock_submit.assert_called_once_with(timeout=10) + mock_wait.assert_called_once() + mock_deliver.assert_called_once_with( + {'properties': {'delivery_info': {}}}, 'queue' + ) + + mock_deliver_call_count = mock_deliver.call_count + transport._callbacks = {} + transport._drain_from_active_queues(timeout=10) + assert mock_deliver_call_count == mock_deliver.call_count + + def test_drain_from_active_queues_exception(self, transport): + future = Future() + future.set_exception(Exception("Test exception")) + + with patch.object( + transport, '_rm_empty_bulk_requests' + ) as mock_rm_empty, patch.object( + transport, '_submit_get_bulk_requests' + ) as mock_submit, patch( + 'kombu.transport.gcpubsub.wait', return_value=({future}, set()) + ) as mock_wait: + with pytest.raises(Empty): + transport._drain_from_active_queues(timeout=10) + mock_rm_empty.assert_called_once() + mock_submit.assert_called_once_with(timeout=10) + mock_wait.assert_called_once() + + def test_rm_empty_bulk_requests(self, transport): + # Create futures with exceptions to simulate empty requests + future_with_exception = Future() + future_with_exception.set_exception(Exception("Test exception")) + + transport._get_bulk_future_to_queue = { + future_with_exception: 'queue1', + } + + transport._rm_empty_bulk_requests() + + # Assert that the future with exception is removed + assert ( + future_with_exception not in transport._get_bulk_future_to_queue + ) + + def test_submit_get_bulk_requests(self, transport): + channel_mock = MagicMock(spec=Channel) + channel_mock._active_queues = ['queue1', 'queue2'] + transport.channels = [channel_mock] + + with patch.object( + transport._pool, 'submit', return_value=MagicMock() + ) as mock_submit: + transport._submit_get_bulk_requests(timeout=10) + + # Check that submit was called twice, once for each queue + assert mock_submit.call_count == 2 + mock_submit.assert_any_call(channel_mock._get_bulk, 'queue1', 10) + mock_submit.assert_any_call(channel_mock._get_bulk, 'queue2', 10) + + def test_submit_get_bulk_requests_with_existing_futures(self, transport): + channel_mock = MagicMock(spec=Channel) + channel_mock._active_queues = ['queue1', 'queue2'] + transport.channels = [channel_mock] + + # Simulate existing futures + future_mock = MagicMock() + transport._get_bulk_future_to_queue = {future_mock: 'queue1'} + + with patch.object( + transport._pool, 'submit', return_value=MagicMock() + ) as mock_submit: + transport._submit_get_bulk_requests(timeout=10) + + # Check that submit was called only once for the new queue + assert mock_submit.call_count == 1 + mock_submit.assert_called_with( + channel_mock._get_bulk, 'queue2', 10 + ) diff --git a/tox.ini b/tox.ini index b1259e0e..7b45b9b4 100644 --- a/tox.ini +++ b/tox.ini @@ -25,7 +25,9 @@ python = [testenv] sitepackages = False -setenv = C_DEBUG_TEST = 1 +setenv = + C_DEBUG_TEST = 1 + PIP_EXTRA_INDEX_URL=https://celery.github.io/celery-wheelhouse/repo/simple/ passenv = DISTUTILS_USE_SDK deps=