mirror of https://github.com/celery/kombu.git
Add support for Google Pub/Sub as transport broker (#2147)
* Add support for Google Pub/Sub as transport broker
* Add tests
* Add docs
* flake8
* flake8
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Fix future import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Add missing test requirements
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Add monitoring dependency
* Fix test for python3.8
* Mock better google's monitoring api
* Flake8
* Add refdoc
* Add extra url to workaround pypy grpcio
* Add extra index url in tox for grpcio/pypy support
* Revert "Add extra url to workaround pypy grpcio"
This reverts commit dfde4d523c.
* pin grpcio version to match extra_index wheel
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Reduce poll calls if qos denies msg rx
---------
Co-authored-by: Haim Daniel <haimdaniel@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tomer Nosrati <tomer.nosrati@gmail.com>
parent 5a88a28f31, commit 2f58823312
@@ -0,0 +1,24 @@
==============================================================
 Google Cloud Pub/Sub Transport - ``kombu.transport.gcpubsub``
==============================================================

.. currentmodule:: kombu.transport.gcpubsub

.. automodule:: kombu.transport.gcpubsub

    .. contents::
        :local:

    Transport
    ---------

    .. autoclass:: Transport
        :members:
        :undoc-members:

    Channel
    -------

    .. autoclass:: Channel
        :members:
        :undoc-members:
@@ -43,7 +43,8 @@ TRANSPORT_ALIASES = {
     'etcd': 'kombu.transport.etcd:Transport',
     'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
     'azureservicebus': 'kombu.transport.azureservicebus:Transport',
-    'pyro': 'kombu.transport.pyro:Transport'
+    'pyro': 'kombu.transport.pyro:Transport',
+    'gcpubsub': 'kombu.transport.gcpubsub:Transport',
 }

 _transport_cache = {}
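With the alias registered, broker URLs of the form gcpubsub://projects/<project-id> resolve to the new transport. A minimal connection sketch (the project id and option values below are illustrative, not part of this commit):

    from kombu import Connection

    # 'my-project' is a placeholder GCP project id.
    connection = Connection(
        'gcpubsub://projects/my-project',
        transport_options={'wait_time_seconds': 10, 'bulk_max_messages': 32},
    )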
@@ -0,0 +1,810 @@
"""GCP Pub/Sub transport module for kombu.

More information about GCP Pub/Sub:
https://cloud.google.com/pubsub

Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: No
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No

Connection String
=================

Connection string has the following formats:

.. code-block::

    gcpubsub://projects/project-name

Transport Options
=================
* ``queue_name_prefix``: (str) Prefix for queue names.
* ``ack_deadline_seconds``: (int) The maximum time to acknowledge a received
  message before Pub/Sub redelivers it.
* ``expiration_seconds``: (int) Subscriptions without any subscriber
  activity or changes made to their properties are removed after this period.
  Examples of subscriber activities include open connections,
  active pulls, or successful pushes.
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
  Defaults to 10.
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
  Defaults to 32.
"""

from __future__ import annotations

import dataclasses
import datetime
import string
import threading
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
                                wait)
from contextlib import suppress
from os import getpid
from queue import Empty
from threading import Lock
from time import monotonic, sleep
from uuid import NAMESPACE_OID, uuid3

from _socket import gethostname
from _socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
                                        PermissionDenied)
from google.api_core.retry import Retry
from google.cloud import monitoring_v3
from google.cloud.monitoring_v3 import query
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
from google.cloud.pubsub_v1.subscriber import \
    exceptions as subscriber_exceptions
from google.pubsub_v1 import gapic_version as package_version

from kombu.entity import TRANSIENT_DELIVERY_MODE
from kombu.log import get_logger
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property

from . import virtual

logger = get_logger('kombu.transport.gcpubsub')

# dots are replaced by dash, all other punctuation replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
    ord('.'): ord('-'),
    **{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
}


class UnackedIds:
    """Threadsafe list of ack_ids."""

    def __init__(self):
        self._list = []
        self._lock = Lock()

    def append(self, val):
        # append is atomic
        self._list.append(val)

    def extend(self, vals: list):
        # extend is atomic
        self._list.extend(vals)

    def pop(self, index=-1):
        with self._lock:
            return self._list.pop(index)

    def remove(self, val):
        with self._lock, suppress(ValueError):
            self._list.remove(val)

    def __len__(self):
        with self._lock:
            return len(self._list)

    def __getitem__(self, item):
        # getitem is atomic
        return self._list[item]


class AtomicCounter:
    """Threadsafe counter.

    Returns the value after inc/dec operations.
    """

    def __init__(self, initial=0):
        self._value = initial
        self._lock = Lock()

    def inc(self, n=1):
        with self._lock:
            self._value += n
            return self._value

    def dec(self, n=1):
        with self._lock:
            self._value -= n
            return self._value

    def get(self):
        with self._lock:
            return self._value


@dataclasses.dataclass
class QueueDescriptor:
    """Pub/Sub queue descriptor."""

    name: str
    topic_path: str  # projects/{project_id}/topics/{topic_id}
    subscription_id: str
    subscription_path: str  # projects/{project_id}/subscriptions/{subscription_id}
    unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)


class Channel(virtual.Channel):
    """GCP Pub/Sub channel."""

    supports_fanout = True
    do_restore = False  # pub/sub does that for us
    default_wait_time_seconds = 10
    default_ack_deadline_seconds = 240
    default_expiration_seconds = 86400
    default_retry_timeout_seconds = 300
    default_bulk_max_messages = 32

    _min_ack_deadline = 10
    _fanout_exchanges = set()
    _unacked_extender: threading.Thread = None
    _stop_extender = threading.Event()
    _n_channels = AtomicCounter()
    _queue_cache: dict[str, QueueDescriptor] = {}
    _tmp_subscriptions: set[str] = set()

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pool = ThreadPoolExecutor()
        logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)
        self.project_id = Transport.parse_uri(self.conninfo.hostname)
        if self._n_channels.inc() == 1:
            Channel._unacked_extender = threading.Thread(
                target=self._extend_unacked_deadline,
                daemon=True,
            )
            self._stop_extender.clear()
            Channel._unacked_extender.start()

    def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
        """Format AMQP queue name into a valid Pub/Sub queue name."""
        if not name.startswith(self.queue_name_prefix):
            name = self.queue_name_prefix + name

        return str(safe_str(name)).translate(table)
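
    # For example, with the default 'kombu-' prefix (values illustrative):
    #   entity_name('celery.pidbox') -> 'kombu-celery-pidbox'
    # ('.' is translated to '-', all other punctuation to '_').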

    def _queue_bind(self, exchange, routing_key, pattern, queue):
        exchange_type = self.typeof(exchange).type
        queue = self.entity_name(queue)
        logger.debug(
            'binding queue: %s to %s exchange: %s with routing_key: %s',
            queue,
            exchange_type,
            exchange,
            routing_key,
        )

        filter_args = {}
        if exchange_type == 'direct':
            # Direct exchange is implemented as a single subscription.
            # E.g. for exchange 'test_direct':
            # - topic: 'test_direct'
            #   - bound queue 'direct1':
            #     - subscription: 'direct1' on topic 'test_direct'
            #       - filter: routing_key
            filter_args = {
                'filter': f'attributes.routing_key="{routing_key}"'
            }
            subscription_path = self.subscriber.subscription_path(
                self.project_id, queue
            )
            message_retention_duration = self.expiration_seconds
        elif exchange_type == 'fanout':
            # Fanout exchange is implemented as a separate subscription
            # per queue. E.g. for exchange 'test_fanout':
            # - topic: 'test_fanout'
            #   - bound queue 'fanout1':
            #     - subscription: 'fanout1-uuid' on topic 'test_fanout'
            #   - bound queue 'fanout2':
            #     - subscription: 'fanout2-uuid' on topic 'test_fanout'
            uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
            uniq_sub_name = f'{queue}-{uid}'
            subscription_path = self.subscriber.subscription_path(
                self.project_id, uniq_sub_name
            )
            self._tmp_subscriptions.add(subscription_path)
            self._fanout_exchanges.add(exchange)
            message_retention_duration = 600
        else:
            raise NotImplementedError(
                f'exchange type {exchange_type} not implemented'
            )
        exchange_topic = self._create_topic(
            self.project_id, exchange, message_retention_duration
        )
        self._create_subscription(
            topic_path=exchange_topic,
            subscription_path=subscription_path,
            filter_args=filter_args,
            msg_retention=message_retention_duration,
        )
        qdesc = QueueDescriptor(
            name=queue,
            topic_path=exchange_topic,
            subscription_id=queue,
            subscription_path=subscription_path,
        )
        self._queue_cache[queue] = qdesc

    def _create_topic(
        self,
        project_id: str,
        topic_id: str,
        message_retention_duration: int = None,
    ) -> str:
        topic_path = self.publisher.topic_path(project_id, topic_id)
        if self._is_topic_exists(topic_path):
            # topic creation takes a while, so skip if possible
            logger.debug('topic: %s exists', topic_path)
            return topic_path
        try:
            logger.debug('creating topic: %s', topic_path)
            request = {'name': topic_path}
            if message_retention_duration:
                request[
                    'message_retention_duration'
                ] = f'{message_retention_duration}s'
            self.publisher.create_topic(request=request)
        except AlreadyExists:
            pass

        return topic_path

    def _is_topic_exists(self, topic_path: str) -> bool:
        topics = self.publisher.list_topics(
            request={"project": f'projects/{self.project_id}'}
        )
        for t in topics:
            if t.name == topic_path:
                return True
        return False

    def _create_subscription(
        self,
        project_id: str = None,
        topic_id: str = None,
        topic_path: str = None,
        subscription_path: str = None,
        filter_args=None,
        msg_retention: int = None,
    ) -> str:
        subscription_path = (
            subscription_path
            or self.subscriber.subscription_path(self.project_id, topic_id)
        )
        topic_path = topic_path or self.publisher.topic_path(
            project_id, topic_id
        )
        try:
            logger.debug(
                'creating subscription: %s, topic: %s, filter: %s',
                subscription_path,
                topic_path,
                filter_args,
            )
            msg_retention = msg_retention or self.expiration_seconds
            self.subscriber.create_subscription(
                request={
                    "name": subscription_path,
                    "topic": topic_path,
                    'ack_deadline_seconds': self.ack_deadline_seconds,
                    'expiration_policy': {
                        'ttl': f'{self.expiration_seconds}s'
                    },
                    'message_retention_duration': f'{msg_retention}s',
                    **(filter_args or {}),
                }
            )
        except AlreadyExists:
            pass
        return subscription_path

    def _delete(self, queue, *args, **kwargs):
        """Delete a queue by name."""
        queue = self.entity_name(queue)
        logger.info('deleting queue: %s', queue)
        qdesc = self._queue_cache.get(queue)
        if not qdesc:
            return
        self.subscriber.delete_subscription(
            request={"subscription": qdesc.subscription_path}
        )
        self._queue_cache.pop(queue, None)

    def _put(self, queue, message, **kwargs):
        """Put a message onto the queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache[queue]
        routing_key = self._get_routing_key(message)
        logger.debug(
            'putting message to queue: %s, topic: %s, routing_key: %s',
            queue,
            qdesc.topic_path,
            routing_key,
        )
        encoded_message = dumps(message)
        self.publisher.publish(
            qdesc.topic_path,
            encoded_message.encode("utf-8"),
            routing_key=routing_key,
        )

    def _put_fanout(self, exchange, message, routing_key, **kwargs):
        """Put a message onto fanout exchange."""
        self._lookup(exchange, routing_key)
        topic_path = self.publisher.topic_path(self.project_id, exchange)
        logger.debug(
            'putting msg to fanout exchange: %s, topic: %s',
            exchange,
            topic_path,
        )
        encoded_message = dumps(message)
        self.publisher.publish(
            topic_path,
            encoded_message.encode("utf-8"),
            retry=Retry(deadline=self.retry_timeout_seconds),
        )

    def _get(self, queue: str, timeout: float = None):
        """Retrieve a single message from a queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache[queue]
        try:
            response = self.subscriber.pull(
                request={
                    'subscription': qdesc.subscription_path,
                    'max_messages': 1,
                },
                retry=Retry(deadline=self.retry_timeout_seconds),
                timeout=timeout or self.wait_time_seconds,
            )
        except DeadlineExceeded:
            raise Empty()

        if len(response.received_messages) == 0:
            raise Empty()

        message = response.received_messages[0]
        ack_id = message.ack_id
        payload = loads(message.message.data)
        delivery_info = payload['properties']['delivery_info']
        logger.debug(
            'queue:%s got message, ack_id: %s, payload: %s',
            queue,
            ack_id,
            payload['properties'],
        )
        if self._is_auto_ack(payload['properties']):
            logger.debug('auto acking message ack_id: %s', ack_id)
            self._do_ack([ack_id], qdesc.subscription_path)
        else:
            delivery_info['gcpubsub_message'] = {
                'queue': queue,
                'ack_id': ack_id,
                'message_id': message.message.message_id,
                'subscription_path': qdesc.subscription_path,
            }
            qdesc.unacked_ids.append(ack_id)

        return payload

    def _is_auto_ack(self, payload_properties: dict):
        exchange = payload_properties['delivery_info']['exchange']
        delivery_mode = payload_properties['delivery_mode']
        return (
            delivery_mode == TRANSIENT_DELIVERY_MODE
            or exchange in self._fanout_exchanges
        )

    def _get_bulk(self, queue: str, timeout: float):
        """Retrieve a bulk of messages from a queue."""
        prefixed_queue = self.entity_name(queue)
        qdesc = self._queue_cache[prefixed_queue]
        max_messages = self._get_max_messages_estimate()
        if not max_messages:
            raise Empty()
        try:
            response = self.subscriber.pull(
                request={
                    'subscription': qdesc.subscription_path,
                    'max_messages': max_messages,
                },
                retry=Retry(deadline=self.retry_timeout_seconds),
                timeout=timeout or self.wait_time_seconds,
            )
        except DeadlineExceeded:
            raise Empty()

        received_messages = response.received_messages
        if len(received_messages) == 0:
            raise Empty()

        auto_ack_ids = []
        ret_payloads = []
        logger.debug(
            'batching %d messages from queue: %s',
            len(received_messages),
            prefixed_queue,
        )
        for message in received_messages:
            ack_id = message.ack_id
            payload = loads(bytes_to_str(message.message.data))
            delivery_info = payload['properties']['delivery_info']
            delivery_info['gcpubsub_message'] = {
                'queue': prefixed_queue,
                'ack_id': ack_id,
                'message_id': message.message.message_id,
                'subscription_path': qdesc.subscription_path,
            }
            if self._is_auto_ack(payload['properties']):
                auto_ack_ids.append(ack_id)
            else:
                qdesc.unacked_ids.append(ack_id)
            ret_payloads.append(payload)
        if auto_ack_ids:
            logger.debug('auto acking ack_ids: %s', auto_ack_ids)
            self._do_ack(auto_ack_ids, qdesc.subscription_path)

        return queue, ret_payloads

    def _get_max_messages_estimate(self) -> int:
        max_allowed = self.qos.can_consume_max_estimate()
        max_if_unlimited = self.bulk_max_messages
        return max_if_unlimited if max_allowed is None else max_allowed
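
    # Illustrative: if the QoS prefetch window only allows one more message,
    # at most one is pulled; an estimate of 0 makes _get_bulk raise Empty
    # before calling pull at all, so no Pub/Sub round trip is wasted while
    # QoS denies message delivery.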

    def _lookup(self, exchange, routing_key, default=None):
        exchange_info = self.state.exchanges.get(exchange, {})
        if not exchange_info:
            return super()._lookup(exchange, routing_key, default)
        ret = self.typeof(exchange).lookup(
            self.get_table(exchange),
            exchange,
            routing_key,
            default,
        )
        if ret:
            return ret
        logger.debug(
            'no queues bound to exchange: %s, binding on the fly',
            exchange,
        )
        self.queue_bind(exchange, exchange, routing_key)
        return [exchange]

    def _size(self, queue: str) -> int:
        """Return the number of messages in a queue.

        This is a *rough* estimation, as Pub/Sub doesn't provide
        an exact API.
        """
        queue = self.entity_name(queue)
        if queue not in self._queue_cache:
            return 0
        qdesc = self._queue_cache[queue]
        result = query.Query(
            self.monitor,
            self.project_id,
            'pubsub.googleapis.com/subscription/num_undelivered_messages',
            end_time=datetime.datetime.now(),
            minutes=1,
        ).select_resources(subscription_id=qdesc.subscription_id)

        # monitoring API requires the caller to have the monitoring.viewer
        # role. Since we can live without the exact number of messages
        # in the queue, we can ignore the exception and allow users to
        # use the transport without this role.
        with suppress(PermissionDenied):
            return sum(
                content.points[0].value.int64_value for content in result
            )
        return -1

    def basic_ack(self, delivery_tag, multiple=False):
        """Acknowledge one message."""
        if multiple:
            raise NotImplementedError('multiple acks not implemented')

        delivery_info = self.qos.get(delivery_tag).delivery_info
        pubsub_message = delivery_info['gcpubsub_message']
        ack_id = pubsub_message['ack_id']
        queue = pubsub_message['queue']
        logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
        subscription_path = pubsub_message['subscription_path']
        self._do_ack([ack_id], subscription_path)
        qdesc = self._queue_cache[queue]
        qdesc.unacked_ids.remove(ack_id)
        super().basic_ack(delivery_tag)

    def _do_ack(self, ack_ids: list[str], subscription_path: str):
        self.subscriber.acknowledge(
            request={"subscription": subscription_path, "ack_ids": ack_ids},
            retry=Retry(deadline=self.retry_timeout_seconds),
        )

    def _purge(self, queue: str):
        """Delete all current messages in a queue."""
        queue = self.entity_name(queue)
        qdesc = self._queue_cache.get(queue)
        if not qdesc:
            return

        n = self._size(queue)
        self.subscriber.seek(
            request={
                "subscription": qdesc.subscription_path,
                "time": datetime.datetime.now(),
            }
        )
        return n

    def _extend_unacked_deadline(self):
        thread_id = threading.get_native_id()
        logger.info(
            'unacked deadline extension thread: [%s] started',
            thread_id,
        )
        min_deadline_sleep = self._min_ack_deadline / 2
        sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
        while not self._stop_extender.wait(sleep_time):
            for qdesc in self._queue_cache.values():
                if len(qdesc.unacked_ids) == 0:
                    logger.debug(
                        'thread [%s]: no unacked messages for %s',
                        thread_id,
                        qdesc.subscription_path,
                    )
                    continue
                logger.debug(
                    'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
                    thread_id,
                    qdesc.subscription_path,
                    len(qdesc.unacked_ids),
                    list(qdesc.unacked_ids),
                )
                self.subscriber.modify_ack_deadline(
                    request={
                        "subscription": qdesc.subscription_path,
                        "ack_ids": list(qdesc.unacked_ids),
                        "ack_deadline_seconds": self.ack_deadline_seconds,
                    }
                )
        logger.info(
            'unacked deadline extension thread [%s] stopped', thread_id
        )
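
    # Cadence note (derived from the defaults above): with
    # ack_deadline_seconds=240 the extender thread wakes every 60s
    # (deadline / 4, but never less than _min_ack_deadline / 2 = 5s).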

    def after_reply_message_received(self, queue: str):
        queue = self.entity_name(queue)
        sub = self.subscriber.subscription_path(self.project_id, queue)
        logger.debug(
            'after_reply_message_received: queue: %s, sub: %s', queue, sub
        )
        self._tmp_subscriptions.add(sub)

    @cached_property
    def subscriber(self):
        return SubscriberClient()

    @cached_property
    def publisher(self):
        return PublisherClient()

    @cached_property
    def monitor(self):
        return monitoring_v3.MetricServiceClient()

    @property
    def conninfo(self):
        return self.connection.client

    @property
    def transport_options(self):
        return self.connection.client.transport_options

    @cached_property
    def wait_time_seconds(self):
        return self.transport_options.get(
            'wait_time_seconds', self.default_wait_time_seconds
        )

    @cached_property
    def retry_timeout_seconds(self):
        return self.transport_options.get(
            'retry_timeout_seconds', self.default_retry_timeout_seconds
        )

    @cached_property
    def ack_deadline_seconds(self):
        return self.transport_options.get(
            'ack_deadline_seconds', self.default_ack_deadline_seconds
        )

    @cached_property
    def queue_name_prefix(self):
        return self.transport_options.get('queue_name_prefix', 'kombu-')

    @cached_property
    def expiration_seconds(self):
        return self.transport_options.get(
            'expiration_seconds', self.default_expiration_seconds
        )

    @cached_property
    def bulk_max_messages(self):
        return self.transport_options.get(
            'bulk_max_messages', self.default_bulk_max_messages
        )

    def close(self):
        """Close the channel."""
        logger.debug('closing channel')
        while self._tmp_subscriptions:
            sub = self._tmp_subscriptions.pop()
            with suppress(Exception):
                logger.debug('deleting subscription: %s', sub)
                self.subscriber.delete_subscription(
                    request={"subscription": sub}
                )
        if not self._n_channels.dec():
            self._stop_extender.set()
            Channel._unacked_extender.join()
        super().close()

    @staticmethod
    def _get_routing_key(message):
        routing_key = (
            message['properties']
            .get('delivery_info', {})
            .get('routing_key', '')
        )
        return routing_key


class Transport(virtual.Transport):
    """GCP Pub/Sub transport."""

    Channel = Channel

    can_parse_url = True
    polling_interval = 0.1
    connection_errors = virtual.Transport.connection_errors + (
        pubsub_exceptions.TimeoutError,
    )
    channel_errors = (
        virtual.Transport.channel_errors
        + (
            publisher_exceptions.FlowControlLimitError,
            publisher_exceptions.MessageTooLargeError,
            publisher_exceptions.PublishError,
            publisher_exceptions.TimeoutError,
            publisher_exceptions.PublishToPausedOrderingKeyException,
        )
        + (subscriber_exceptions.AcknowledgeError,)
    )

    driver_type = 'gcpubsub'
    driver_name = 'pubsub_v1'

    implements = virtual.Transport.implements.extend(
        exchange_type=frozenset(['direct', 'fanout']),
    )

    def __init__(self, client, **kwargs):
        super().__init__(client, **kwargs)
        self._pool = ThreadPoolExecutor()
        self._get_bulk_future_to_queue: dict[Future, str] = dict()

    def driver_version(self):
        return package_version.__version__

    @staticmethod
    def parse_uri(uri: str) -> str:
        # URL like:
        # gcpubsub://projects/project-name

        project = uri.split('gcpubsub://projects/')[1]
        return project.strip('/')
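
    # Illustrative: parse_uri('gcpubsub://projects/my-project/')
    # returns 'my-project'.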

    @classmethod
    def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
        return uri or 'gcpubsub://'

    def drain_events(self, connection, timeout=None):
        time_start = monotonic()
        polling_interval = self.polling_interval
        if timeout and polling_interval and polling_interval > timeout:
            polling_interval = timeout
        while 1:
            try:
                self._drain_from_active_queues(timeout=timeout)
            except Empty:
                if timeout and monotonic() - time_start >= timeout:
                    raise socket_timeout()
                if polling_interval:
                    sleep(polling_interval)
            else:
                break

    def _drain_from_active_queues(self, timeout):
        # cleanup empty requests from prev run
        self._rm_empty_bulk_requests()

        # submit new requests for all active queues
        # longer timeout means less frequent polling
        # and more messages in a single bulk
        self._submit_get_bulk_requests(timeout=10)

        done, _ = wait(
            self._get_bulk_future_to_queue,
            timeout=timeout,
            return_when=FIRST_COMPLETED,
        )
        empty = {f for f in done if f.exception()}
        done -= empty
        for f in empty:
            self._get_bulk_future_to_queue.pop(f, None)

        if not done:
            raise Empty()

        logger.debug('got %d done get_bulk tasks', len(done))
        for f in done:
            queue, payloads = f.result()
            for payload in payloads:
                logger.debug('consuming message from queue: %s', queue)
                if queue not in self._callbacks:
                    logger.warning(
                        'Message for queue %s without consumers', queue
                    )
                    continue
                self._deliver(payload, queue)
            self._get_bulk_future_to_queue.pop(f, None)

    def _rm_empty_bulk_requests(self):
        empty = {
            f
            for f in self._get_bulk_future_to_queue
            if f.done() and f.exception()
        }
        for f in empty:
            self._get_bulk_future_to_queue.pop(f, None)

    def _submit_get_bulk_requests(self, timeout):
        queues_with_submitted_get_bulk = set(
            self._get_bulk_future_to_queue.values()
        )

        for channel in self.channels:
            for queue in channel._active_queues:
                if queue in queues_with_submitted_get_bulk:
                    continue
                future = self._pool.submit(channel._get_bulk, queue, timeout)
                self._get_bulk_future_to_queue[future] = queue
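Taken together, a minimal publish sketch over the new transport (the project id, exchange and queue names below are illustrative placeholders):

    from kombu import Connection, Exchange, Queue

    exchange = Exchange('test_direct', type='direct')
    queue = Queue('direct1', exchange, routing_key='direct1')

    with Connection('gcpubsub://projects/my-project') as conn:
        producer = conn.Producer()
        producer.publish(
            {'hello': 'world'},
            exchange=exchange,
            routing_key='direct1',
            declare=[queue],
        )

Direct-exchange deliveries are routed by the attributes.routing_key subscription filter created in _queue_bind above.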
@@ -0,0 +1,3 @@
google-cloud-pubsub>=2.18.4
google-cloud-monitoring>=2.16.0
grpcio==1.66.2
@@ -16,3 +16,4 @@ urllib3>=1.26.16; sys_platform != 'win32'
 -r extras/zstd.txt
 -r extras/sqlalchemy.txt
 -r extras/etcd.txt
+-r extras/gcpubsub.txt
setup.py
@@ -103,6 +103,7 @@ setup(
     'redis': extras('redis.txt'),
     'mongodb': extras('mongodb.txt'),
     'sqs': extras('sqs.txt'),
+    'gcpubsub': extras('gcpubsub.txt'),
     'zookeeper': extras('zookeeper.txt'),
     'sqlalchemy': extras('sqlalchemy.txt'),
     'librabbitmq': extras('librabbitmq.txt'),
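With this extra registered, the Pub/Sub dependencies can be installed via pip install 'kombu[gcpubsub]'.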
@@ -0,0 +1,793 @@
from __future__ import annotations

from concurrent.futures import Future
from datetime import datetime
from queue import Empty
from unittest.mock import MagicMock, call, patch

import pytest
from _socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
                                        PermissionDenied)

from kombu.transport.gcpubsub import (AtomicCounter, Channel, QueueDescriptor,
                                      Transport, UnackedIds)


class test_UnackedIds:
    def setup_method(self):
        self.unacked_ids = UnackedIds()

    def test_append(self):
        self.unacked_ids.append('test_id')
        assert self.unacked_ids[0] == 'test_id'

    def test_extend(self):
        self.unacked_ids.extend(['test_id1', 'test_id2'])
        assert self.unacked_ids[0] == 'test_id1'
        assert self.unacked_ids[1] == 'test_id2'

    def test_pop(self):
        self.unacked_ids.append('test_id')
        popped_id = self.unacked_ids.pop()
        assert popped_id == 'test_id'
        assert len(self.unacked_ids) == 0

    def test_remove(self):
        self.unacked_ids.append('test_id')
        self.unacked_ids.remove('test_id')
        assert len(self.unacked_ids) == 0

    def test_len(self):
        self.unacked_ids.append('test_id')
        assert len(self.unacked_ids) == 1

    def test_getitem(self):
        self.unacked_ids.append('test_id')
        assert self.unacked_ids[0] == 'test_id'


class test_AtomicCounter:
    def setup_method(self):
        self.counter = AtomicCounter()

    def test_inc(self):
        assert self.counter.inc() == 1
        assert self.counter.inc(5) == 6

    def test_dec(self):
        self.counter.inc(5)
        assert self.counter.dec() == 4
        assert self.counter.dec(2) == 2

    def test_get(self):
        self.counter.inc(7)
        assert self.counter.get() == 7


@pytest.fixture
def channel():
    with patch.object(Channel, '__init__', lambda self: None):
        channel = Channel()
        channel.connection = MagicMock()
        channel.queue_name_prefix = "kombu-"
        channel.project_id = "test_project"
        channel._queue_cache = {}
        channel._n_channels = MagicMock()
        channel._stop_extender = MagicMock()
        channel.subscriber = MagicMock()
        channel.publisher = MagicMock()
        channel.closed = False
        with patch.object(
            Channel, 'conninfo', new_callable=MagicMock
        ), patch.object(
            Channel, 'transport_options', new_callable=MagicMock
        ), patch.object(
            Channel, 'qos', new_callable=MagicMock
        ):
            yield channel


class test_Channel:
    @patch('kombu.transport.gcpubsub.ThreadPoolExecutor')
    @patch('kombu.transport.gcpubsub.threading.Event')
    @patch('kombu.transport.gcpubsub.threading.Thread')
    @patch(
        'kombu.transport.gcpubsub.Channel._get_free_channel_id',
        return_value=1,
    )
    @patch(
        'kombu.transport.gcpubsub.Channel._n_channels.inc',
        return_value=1,
    )
    def test_channel_init(
        self,
        n_channels_in_mock,
        channel_id_mock,
        mock_thread,
        mock_event,
        mock_executor,
    ):
        mock_connection = MagicMock()
        ch = Channel(mock_connection)
        ch._n_channels.inc.assert_called_once()
        mock_thread.assert_called_once_with(
            target=ch._extend_unacked_deadline,
            daemon=True,
        )
        mock_thread.return_value.start.assert_called_once()

    def test_entity_name(self, channel):
        name = "test_queue"
        result = channel.entity_name(name)
        assert result == "kombu-test_queue"

    @patch('kombu.transport.gcpubsub.uuid3', return_value='uuid')
    @patch('kombu.transport.gcpubsub.gethostname', return_value='hostname')
    @patch('kombu.transport.gcpubsub.getpid', return_value=1234)
    def test_queue_bind_direct(
        self, mock_pid, mock_hostname, mock_uuid, channel
    ):
        exchange = 'direct'
        routing_key = 'test_routing_key'
        pattern = 'test_pattern'
        queue = 'test_queue'
        subscription_path = 'projects/project-id/subscriptions/test_queue'
        channel.subscriber.subscription_path = MagicMock(
            return_value=subscription_path
        )
        channel._create_topic = MagicMock(return_value='topic_path')
        channel._create_subscription = MagicMock()

        # Mock the state and exchange type
        mock_connection = MagicMock(name='mock_connection')
        channel.connection = mock_connection
        channel.state.exchanges = {exchange: {'type': 'direct'}}
        mock_exchange = MagicMock(name='mock_exchange', type='direct')
        channel.exchange_types = {'direct': mock_exchange}

        channel._queue_bind(exchange, routing_key, pattern, queue)

        channel._create_topic.assert_called_once_with(
            channel.project_id, exchange, channel.expiration_seconds
        )
        channel._create_subscription.assert_called_once_with(
            topic_path='topic_path',
            subscription_path=subscription_path,
            filter_args={'filter': f'attributes.routing_key="{routing_key}"'},
            msg_retention=channel.expiration_seconds,
        )
        assert channel.entity_name(queue) in channel._queue_cache

    @patch('kombu.transport.gcpubsub.uuid3', return_value='uuid')
    @patch('kombu.transport.gcpubsub.gethostname', return_value='hostname')
    @patch('kombu.transport.gcpubsub.getpid', return_value=1234)
    def test_queue_bind_fanout(
        self, mock_pid, mock_hostname, mock_uuid, channel
    ):
        exchange = 'test_exchange'
        routing_key = 'test_routing_key'
        pattern = 'test_pattern'
        queue = 'test_queue'
        uniq_sub_name = 'test_queue-uuid'
        subscription_path = (
            f'projects/project-id/subscriptions/{uniq_sub_name}'
        )
        channel.subscriber.subscription_path = MagicMock(
            return_value=subscription_path
        )
        channel._create_topic = MagicMock(return_value='topic_path')
        channel._create_subscription = MagicMock()

        # Mock the state and exchange type
        mock_connection = MagicMock(name='mock_connection')
        channel.connection = mock_connection
        channel.state.exchanges = {exchange: {'type': 'fanout'}}
        mock_exchange = MagicMock(name='mock_exchange', type='fanout')
        channel.exchange_types = {'fanout': mock_exchange}

        channel._queue_bind(exchange, routing_key, pattern, queue)

        channel._create_topic.assert_called_once_with(
            channel.project_id, exchange, 600
        )
        channel._create_subscription.assert_called_once_with(
            topic_path='topic_path',
            subscription_path=subscription_path,
            filter_args={},
            msg_retention=600,
        )
        assert channel.entity_name(queue) in channel._queue_cache
        assert subscription_path in channel._tmp_subscriptions
        assert exchange in channel._fanout_exchanges

    def test_queue_bind_not_implemented(self, channel):
        exchange = 'test_exchange'
        routing_key = 'test_routing_key'
        pattern = 'test_pattern'
        queue = 'test_queue'
        channel.typeof = MagicMock(return_value=MagicMock(type='unsupported'))

        with pytest.raises(NotImplementedError):
            channel._queue_bind(exchange, routing_key, pattern, queue)

    def test_create_topic(self, channel):
        channel.project_id = "project_id"
        topic_id = "topic_id"
        channel._is_topic_exists = MagicMock(return_value=False)
        channel.publisher.topic_path = MagicMock(return_value="topic_path")
        channel.publisher.create_topic = MagicMock()
        result = channel._create_topic(channel.project_id, topic_id)
        assert result == "topic_path"
        channel.publisher.create_topic.assert_called_once()

        channel._create_topic(
            channel.project_id, topic_id, message_retention_duration=10
        )
        assert (
            dict(
                request={
                    'name': 'topic_path',
                    'message_retention_duration': '10s',
                }
            )
            in channel.publisher.create_topic.call_args
        )
        channel.publisher.create_topic.side_effect = AlreadyExists(
            "test_error"
        )
        channel._create_topic(
            channel.project_id, topic_id, message_retention_duration=10
        )

    def test_is_topic_exists(self, channel):
        topic_path = "projects/project-id/topics/test_topic"
        mock_topic = MagicMock()
        mock_topic.name = topic_path
        channel.publisher.list_topics.return_value = [mock_topic]

        result = channel._is_topic_exists(topic_path)

        assert result is True
        channel.publisher.list_topics.assert_called_once_with(
            request={"project": f'projects/{channel.project_id}'}
        )

    def test_is_topic_not_exists(self, channel):
        topic_path = "projects/project-id/topics/test_topic"
        channel.publisher.list_topics.return_value = []

        result = channel._is_topic_exists(topic_path)

        assert result is False
        channel.publisher.list_topics.assert_called_once_with(
            request={"project": f'projects/{channel.project_id}'}
        )

    def test_create_subscription(self, channel):
        channel.project_id = "project_id"
        topic_id = "topic_id"
        subscription_path = "subscription_path"
        topic_path = "topic_path"
        channel.subscriber.subscription_path = MagicMock(
            return_value=subscription_path
        )
        channel.publisher.topic_path = MagicMock(return_value=topic_path)
        channel.subscriber.create_subscription = MagicMock()
        result = channel._create_subscription(
            project_id=channel.project_id,
            topic_id=topic_id,
            subscription_path=subscription_path,
            topic_path=topic_path,
        )
        assert result == subscription_path
        channel.subscriber.create_subscription.assert_called_once()

    def test_delete(self, channel):
        queue = "test_queue"
        subscription_path = "projects/project-id/subscriptions/test_queue"
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id=queue,
            subscription_path=subscription_path,
        )
        channel.subscriber = MagicMock()
        channel._queue_cache[channel.entity_name(queue)] = qdesc

        channel._delete(queue)

        channel.subscriber.delete_subscription.assert_called_once_with(
            request={"subscription": subscription_path}
        )
        assert queue not in channel._queue_cache

    def test_put(self, channel):
        queue = "test_queue"
        message = {
            "properties": {"delivery_info": {"routing_key": "test_key"}}
        }
        channel.entity_name = MagicMock(return_value=queue)
        channel._queue_cache[channel.entity_name(queue)] = QueueDescriptor(
            name=queue,
            topic_path="topic_path",
            subscription_id=queue,
            subscription_path="subscription_path",
        )
        channel._get_routing_key = MagicMock(return_value="test_key")
        channel.publisher.publish = MagicMock()
        channel._put(queue, message)
        channel.publisher.publish.assert_called_once()

    def test_put_fanout(self, channel):
        exchange = "test_exchange"
        message = {
            "properties": {"delivery_info": {"routing_key": "test_key"}}
        }
        routing_key = "test_key"

        channel._lookup = MagicMock()
        channel.publisher.topic_path = MagicMock(return_value="topic_path")
        channel.publisher.publish = MagicMock()

        channel._put_fanout(exchange, message, routing_key)

        channel._lookup.assert_called_once_with(exchange, routing_key)
        channel.publisher.topic_path.assert_called_once_with(
            channel.project_id, exchange
        )
        assert 'topic_path', (
            b'{"properties": {"delivery_info": {"routing_key": "test_key"}}}'
            in channel.publisher.publish.call_args
        )

    def test_get(self, channel):
        queue = "test_queue"
        channel.entity_name = MagicMock(return_value=queue)
        channel._queue_cache[queue] = QueueDescriptor(
            name=queue,
            topic_path="topic_path",
            subscription_id=queue,
            subscription_path="subscription_path",
        )
        channel.subscriber.pull = MagicMock(
            return_value=MagicMock(
                received_messages=[
                    MagicMock(
                        ack_id="ack_id",
                        message=MagicMock(
                            data=b'{"properties": '
                            b'{"delivery_info": '
                            b'{"exchange": "exchange"},"delivery_mode": 1}}'
                        ),
                    )
                ]
            )
        )
        channel.subscriber.acknowledge = MagicMock()
        payload = channel._get(queue)
        assert (
            payload["properties"]["delivery_info"]["exchange"] == "exchange"
        )
        channel.subscriber.pull.side_effect = DeadlineExceeded("test_error")
        with pytest.raises(Empty):
            channel._get(queue)

    def test_get_bulk(self, channel):
        queue = "test_queue"
        subscription_path = "projects/project-id/subscriptions/test_queue"
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id=queue,
            subscription_path=subscription_path,
        )
        channel._queue_cache[channel.entity_name(queue)] = qdesc

        data = b'{"properties": {"delivery_info": {"exchange": "exchange"}}}'
        received_message = MagicMock(
            ack_id="ack_id",
            message=MagicMock(data=data),
        )
        channel.subscriber.pull = MagicMock(
            return_value=MagicMock(received_messages=[received_message])
        )
        channel.bulk_max_messages = 10
        channel._is_auto_ack = MagicMock(return_value=True)
        channel._do_ack = MagicMock()
        channel.qos.can_consume_max_estimate = MagicMock(return_value=None)
        queue, payloads = channel._get_bulk(queue, timeout=10)

        assert len(payloads) == 1
        assert (
            payloads[0]["properties"]["delivery_info"]["exchange"]
            == "exchange"
        )
        channel._do_ack.assert_called_once_with(["ack_id"], subscription_path)

        channel.subscriber.pull.side_effect = DeadlineExceeded("test_error")
        with pytest.raises(Empty):
            channel._get_bulk(queue, timeout=10)

    def test_lookup(self, channel):
        exchange = "test_exchange"
        routing_key = "test_key"
        default = None

        channel.connection = MagicMock()
        channel.state.exchanges = {exchange: {"type": "direct"}}
        channel.typeof = MagicMock(
            return_value=MagicMock(lookup=MagicMock(return_value=["queue1"]))
        )
        channel.get_table = MagicMock(return_value="table")

        result = channel._lookup(exchange, routing_key, default)

        channel.typeof.return_value.lookup.assert_called_once_with(
            "table", exchange, routing_key, default
        )
        assert result == ["queue1"]

        # Test the case where no queues are bound to the exchange
        channel.typeof.return_value.lookup.return_value = None
        channel.queue_bind = MagicMock()

        result = channel._lookup(exchange, routing_key, default)

        channel.queue_bind.assert_called_once_with(
            exchange, exchange, routing_key
        )
        assert result == [exchange]

    @patch('kombu.transport.gcpubsub.monitoring_v3')
    @patch('kombu.transport.gcpubsub.query.Query')
    def test_size(self, mock_query, mock_monitor, channel):
        queue = "test_queue"
        subscription_id = "test_subscription"
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id=subscription_id,
            subscription_path="projects/project-id/subscriptions/test_subscription",  # noqa: E501
        )
        channel._queue_cache[channel.entity_name(queue)] = qdesc

        mock_query_result = MagicMock()
        mock_query_result.select_resources.return_value = [
            MagicMock(points=[MagicMock(value=MagicMock(int64_value=5))])
        ]
        mock_query.return_value = mock_query_result
        size = channel._size(queue)
        assert size == 5

        # Test the case where the queue is not in the cache
        size = channel._size("non_existent_queue")
        assert size == 0

        # Test the case where the query raises PermissionDenied
        mock_item = MagicMock()
        mock_item.points.__getitem__.side_effect = PermissionDenied(
            'test_error'
        )
        mock_query_result.select_resources.return_value = [mock_item]
        size = channel._size(queue)
        assert size == -1

    def test_basic_ack(self, channel):
        delivery_tag = "test_delivery_tag"
        ack_id = "test_ack_id"
        queue = "test_queue"
        subscription_path = (
            "projects/project-id/subscriptions/test_subscription"
        )
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id="test_subscription",
            subscription_path=subscription_path,
        )
        channel._queue_cache[queue] = qdesc

        delivery_info = {
            'gcpubsub_message': {
                'queue': queue,
                'ack_id': ack_id,
                'subscription_path': subscription_path,
            }
        }
        channel.qos.get = MagicMock(
            return_value=MagicMock(delivery_info=delivery_info)
        )
        channel._do_ack = MagicMock()

        channel.basic_ack(delivery_tag)

        channel._do_ack.assert_called_once_with([ack_id], subscription_path)
        assert ack_id not in qdesc.unacked_ids

    def test_do_ack(self, channel):
        ack_ids = ["ack_id1", "ack_id2"]
        subscription_path = (
            "projects/project-id/subscriptions/test_subscription"
        )
        channel.subscriber = MagicMock()

        channel._do_ack(ack_ids, subscription_path)
        assert subscription_path, (
            ack_ids in channel.subscriber.acknowledge.call_args
        )

    def test_purge(self, channel):
        queue = "test_queue"
        subscription_path = f"projects/project-id/subscriptions/{queue}"
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id="test_subscription",
            subscription_path=subscription_path,
        )
        channel._queue_cache[channel.entity_name(queue)] = qdesc
        channel.subscriber = MagicMock()

        with patch.object(channel, '_size', return_value=10), patch(
            'kombu.transport.gcpubsub.datetime.datetime'
        ) as dt_mock:
            dt_mock.now.return_value = datetime(2021, 1, 1)
            result = channel._purge(queue)
            assert result == 10
            channel.subscriber.seek.assert_called_once_with(
                request={
                    "subscription": subscription_path,
                    "time": datetime(2021, 1, 1),
                }
            )

        # Test the case where the queue is not in the cache
        result = channel._purge("non_existent_queue")
        assert result is None

    def test_extend_unacked_deadline(self, channel):
        queue = "test_queue"
        subscription_path = (
            "projects/project-id/subscriptions/test_subscription"
        )
        ack_ids = ["ack_id1", "ack_id2"]
        qdesc = QueueDescriptor(
            name=queue,
            topic_path="projects/project-id/topics/test_topic",
            subscription_id="test_subscription",
            subscription_path=subscription_path,
        )
        channel.transport_options = {"ack_deadline_seconds": 240}
        channel._queue_cache[channel.entity_name(queue)] = qdesc
        qdesc.unacked_ids.extend(ack_ids)

        channel._stop_extender.wait = MagicMock(side_effect=[False, True])
        channel.subscriber.modify_ack_deadline = MagicMock()

        channel._extend_unacked_deadline()

        channel.subscriber.modify_ack_deadline.assert_called_once_with(
            request={
                "subscription": subscription_path,
                "ack_ids": ack_ids,
                "ack_deadline_seconds": 240,
            }
        )
        for _ in ack_ids:
            qdesc.unacked_ids.pop()
        channel._stop_extender.wait = MagicMock(side_effect=[False, True])
        modify_ack_deadline_calls = (
            channel.subscriber.modify_ack_deadline.call_count
        )
        channel._extend_unacked_deadline()
        assert (
            channel.subscriber.modify_ack_deadline.call_count
            == modify_ack_deadline_calls
        )

    def test_after_reply_message_received(self, channel):
        queue = 'test-queue'
        subscription_path = f'projects/test-project/subscriptions/{queue}'

        channel.subscriber.subscription_path.return_value = subscription_path
        channel.after_reply_message_received(queue)

        # Check that the subscription path is added to _tmp_subscriptions
        assert subscription_path in channel._tmp_subscriptions

    def test_subscriber(self, channel):
        assert channel.subscriber

    def test_publisher(self, channel):
        assert channel.publisher

    def test_transport_options(self, channel):
        assert channel.transport_options

    def test_bulk_max_messages_default(self, channel):
        assert channel.bulk_max_messages == channel.transport_options.get(
            'bulk_max_messages'
        )

    def test_close(self, channel):
        channel._tmp_subscriptions = {'sub1', 'sub2'}
        channel._n_channels.dec.return_value = 0

        with patch.object(
            Channel._unacked_extender, 'join'
        ) as mock_join, patch(
            'kombu.transport.virtual.Channel.close'
        ) as mock_super_close:
            channel.close()

        channel.subscriber.delete_subscription.assert_has_calls(
            [
                call(request={"subscription": 'sub1'}),
                call(request={"subscription": 'sub2'}),
            ],
            any_order=True,
        )
        channel._stop_extender.set.assert_called_once()
        mock_join.assert_called_once()
        mock_super_close.assert_called_once()


@pytest.fixture
def transport():
    return Transport(client=MagicMock())


class test_Transport:
    def test_driver_version(self, transport):
        assert transport.driver_version()

    def test_as_uri(self, transport):
        result = transport.as_uri('gcpubsub://')
        assert result == 'gcpubsub://'

    def test_drain_events_timeout(self, transport):
        transport.polling_interval = 4
        with patch.object(
            transport, '_drain_from_active_queues', side_effect=Empty
        ), patch(
            'kombu.transport.gcpubsub.monotonic',
            side_effect=[0, 1, 2, 3, 4, 5],
        ), patch(
            'kombu.transport.gcpubsub.sleep'
        ) as mock_sleep:
            with pytest.raises(socket_timeout):
                transport.drain_events(None, timeout=3)
            mock_sleep.assert_called()

    def test_drain_events_no_timeout(self, transport):
        with patch.object(
            transport, '_drain_from_active_queues', side_effect=[Empty, None]
        ), patch(
            'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1]
        ), patch(
            'kombu.transport.gcpubsub.sleep'
        ) as mock_sleep:
            transport.drain_events(None, timeout=None)
            mock_sleep.assert_called()

    def test_drain_events_polling_interval(self, transport):
        transport.polling_interval = 2
        with patch.object(
            transport, '_drain_from_active_queues', side_effect=[Empty, None]
        ), patch(
            'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1, 2]
        ), patch(
            'kombu.transport.gcpubsub.sleep'
        ) as mock_sleep:
            transport.drain_events(None, timeout=5)
            mock_sleep.assert_called_with(2)

    def test_drain_from_active_queues_empty(self, transport):
        with patch.object(
            transport, '_rm_empty_bulk_requests'
        ) as mock_rm_empty, patch.object(
            transport, '_submit_get_bulk_requests'
        ) as mock_submit, patch(
            'kombu.transport.gcpubsub.wait', return_value=(set(), set())
        ) as mock_wait:
            with pytest.raises(Empty):
                transport._drain_from_active_queues(timeout=10)
            mock_rm_empty.assert_called_once()
            mock_submit.assert_called_once_with(timeout=10)
            mock_wait.assert_called_once()

    def test_drain_from_active_queues_done(self, transport):
        future = Future()
        future.set_result(('queue', [{'properties': {'delivery_info': {}}}]))

        with patch.object(
            transport, '_rm_empty_bulk_requests'
        ) as mock_rm_empty, patch.object(
            transport, '_submit_get_bulk_requests'
        ) as mock_submit, patch(
            'kombu.transport.gcpubsub.wait', return_value=({future}, set())
        ) as mock_wait, patch.object(
            transport, '_deliver'
        ) as mock_deliver:
            transport._callbacks = {'queue'}
            transport._drain_from_active_queues(timeout=10)
            mock_rm_empty.assert_called_once()
            mock_submit.assert_called_once_with(timeout=10)
            mock_wait.assert_called_once()
            mock_deliver.assert_called_once_with(
                {'properties': {'delivery_info': {}}}, 'queue'
            )

            mock_deliver_call_count = mock_deliver.call_count
            transport._callbacks = {}
            transport._drain_from_active_queues(timeout=10)
            assert mock_deliver_call_count == mock_deliver.call_count

    def test_drain_from_active_queues_exception(self, transport):
        future = Future()
        future.set_exception(Exception("Test exception"))

        with patch.object(
            transport, '_rm_empty_bulk_requests'
        ) as mock_rm_empty, patch.object(
            transport, '_submit_get_bulk_requests'
        ) as mock_submit, patch(
            'kombu.transport.gcpubsub.wait', return_value=({future}, set())
        ) as mock_wait:
            with pytest.raises(Empty):
                transport._drain_from_active_queues(timeout=10)
            mock_rm_empty.assert_called_once()
            mock_submit.assert_called_once_with(timeout=10)
            mock_wait.assert_called_once()

    def test_rm_empty_bulk_requests(self, transport):
        # Create futures with exceptions to simulate empty requests
        future_with_exception = Future()
        future_with_exception.set_exception(Exception("Test exception"))

        transport._get_bulk_future_to_queue = {
            future_with_exception: 'queue1',
        }

        transport._rm_empty_bulk_requests()

        # Assert that the future with exception is removed
        assert (
            future_with_exception not in transport._get_bulk_future_to_queue
        )

    def test_submit_get_bulk_requests(self, transport):
        channel_mock = MagicMock(spec=Channel)
        channel_mock._active_queues = ['queue1', 'queue2']
        transport.channels = [channel_mock]

        with patch.object(
            transport._pool, 'submit', return_value=MagicMock()
        ) as mock_submit:
            transport._submit_get_bulk_requests(timeout=10)

            # Check that submit was called twice, once for each queue
            assert mock_submit.call_count == 2
            mock_submit.assert_any_call(channel_mock._get_bulk, 'queue1', 10)
            mock_submit.assert_any_call(channel_mock._get_bulk, 'queue2', 10)

    def test_submit_get_bulk_requests_with_existing_futures(self, transport):
        channel_mock = MagicMock(spec=Channel)
        channel_mock._active_queues = ['queue1', 'queue2']
        transport.channels = [channel_mock]

        # Simulate existing futures
        future_mock = MagicMock()
        transport._get_bulk_future_to_queue = {future_mock: 'queue1'}

        with patch.object(
            transport._pool, 'submit', return_value=MagicMock()
        ) as mock_submit:
            transport._submit_get_bulk_requests(timeout=10)

            # Check that submit was called only once for the new queue
            assert mock_submit.call_count == 1
            mock_submit.assert_called_with(
                channel_mock._get_bulk, 'queue2', 10
            )