Add support for Google Pub/Sub as transport broker (#2147)

* Add support for Google Pub/Sub as transport broker

 * Add tests
 * Add docs

* flake8

* flake8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix future import

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add missing test requirements

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add monitoring dependency

* Fix test for python3.8

* Mock better google's monitoring api

* Flake8

* Add refdoc

* Add extra url to workaround pypy grpcio

* Add extra index url in tox for grpcio/pypy support

* Revert "Add extra url to workaround pypy grpcio"

This reverts commit dfde4d523c.

* pin grpcio version to match extra_index wheel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reduce poll calls if qos denies msg rx

---------

Co-authored-by: Haim Daniel <haimdaniel@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tomer Nosrati <tomer.nosrati@gmail.com>
This commit is contained in:
Haim Daniel 2024-10-13 16:49:55 +03:00 committed by GitHub
parent 5a88a28f31
commit 2f58823312
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1637 additions and 2 deletions

View File

@ -0,0 +1,24 @@
==============================================================
Google Cloud Pub/Sub Transport - ``kombu.transport.gcpubsub``
==============================================================
.. currentmodule:: kombu.transport.gcpubsub
.. automodule:: kombu.transport.gcpubsub
.. contents::
:local:
Transport
---------
.. autoclass:: Transport
:members:
:undoc-members:
Channel
-------
.. autoclass:: Channel
:members:
:undoc-members:

View File

@ -43,7 +43,8 @@ TRANSPORT_ALIASES = {
'etcd': 'kombu.transport.etcd:Transport',
'azurestoragequeues': 'kombu.transport.azurestoragequeues:Transport',
'azureservicebus': 'kombu.transport.azureservicebus:Transport',
'pyro': 'kombu.transport.pyro:Transport'
'pyro': 'kombu.transport.pyro:Transport',
'gcpubsub': 'kombu.transport.gcpubsub:Transport',
}
_transport_cache = {}

810
kombu/transport/gcpubsub.py Normal file
View File

@ -0,0 +1,810 @@
"""GCP Pub/Sub transport module for kombu.
More information about GCP Pub/Sub:
https://cloud.google.com/pubsub
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: No
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string has the following formats:
.. code-block::
gcpubsub://projects/project-name
Transport Options
=================
* ``queue_name_prefix``: (str) Prefix for queue names.
* ``ack_deadline_seconds``: (int) The maximum time after receiving a message
and acknowledging it before pub/sub redelivers the message.
* ``expiration_seconds``: (int) Subscriptions without any subscriber
activity or changes made to their properties are removed after this period.
Examples of subscriber activities include open connections,
active pulls, or successful pushes.
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
Defaults to 10.
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
Defaults to 32.
"""
from __future__ import annotations
import dataclasses
import datetime
import string
import threading
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
wait)
from contextlib import suppress
from os import getpid
from queue import Empty
from threading import Lock
from time import monotonic, sleep
from uuid import NAMESPACE_OID, uuid3
from _socket import gethostname
from _socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
PermissionDenied)
from google.api_core.retry import Retry
from google.cloud import monitoring_v3
from google.cloud.monitoring_v3 import query
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
from google.cloud.pubsub_v1.subscriber import \
exceptions as subscriber_exceptions
from google.pubsub_v1 import gapic_version as package_version
from kombu.entity import TRANSIENT_DELIVERY_MODE
from kombu.log import get_logger
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
logger = get_logger('kombu.transport.gcpubsub')
# dots are replaced by dash, all other punctuation replaced by underscore.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
ord('.'): ord('-'),
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
}
class UnackedIds:
"""Threadsafe list of ack_ids."""
def __init__(self):
self._list = []
self._lock = Lock()
def append(self, val):
# append is atomic
self._list.append(val)
def extend(self, vals: list):
# extend is atomic
self._list.extend(vals)
def pop(self, index=-1):
with self._lock:
return self._list.pop(index)
def remove(self, val):
with self._lock, suppress(ValueError):
self._list.remove(val)
def __len__(self):
with self._lock:
return len(self._list)
def __getitem__(self, item):
# getitem is atomic
return self._list[item]
class AtomicCounter:
"""Threadsafe counter.
Returns the value after inc/dec operations.
"""
def __init__(self, initial=0):
self._value = initial
self._lock = Lock()
def inc(self, n=1):
with self._lock:
self._value += n
return self._value
def dec(self, n=1):
with self._lock:
self._value -= n
return self._value
def get(self):
with self._lock:
return self._value
@dataclasses.dataclass
class QueueDescriptor:
"""Pub/Sub queue descriptor."""
name: str
topic_path: str # projects/{project_id}/topics/{topic_id}
subscription_id: str
subscription_path: str # projects/{project_id}/subscriptions/{subscription_id}
unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)
class Channel(virtual.Channel):
"""GCP Pub/Sub channel."""
supports_fanout = True
do_restore = False # pub/sub does that for us
default_wait_time_seconds = 10
default_ack_deadline_seconds = 240
default_expiration_seconds = 86400
default_retry_timeout_seconds = 300
default_bulk_max_messages = 32
_min_ack_deadline = 10
_fanout_exchanges = set()
_unacked_extender: threading.Thread = None
_stop_extender = threading.Event()
_n_channels = AtomicCounter()
_queue_cache: dict[str, QueueDescriptor] = {}
_tmp_subscriptions: set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pool = ThreadPoolExecutor()
logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)
self.project_id = Transport.parse_uri(self.conninfo.hostname)
if self._n_channels.inc() == 1:
Channel._unacked_extender = threading.Thread(
target=self._extend_unacked_deadline,
daemon=True,
)
self._stop_extender.clear()
Channel._unacked_extender.start()
def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Pub/Sub queue name."""
if not name.startswith(self.queue_name_prefix):
name = self.queue_name_prefix + name
return str(safe_str(name)).translate(table)
def _queue_bind(self, exchange, routing_key, pattern, queue):
exchange_type = self.typeof(exchange).type
queue = self.entity_name(queue)
logger.debug(
'binding queue: %s to %s exchange: %s with routing_key: %s',
queue,
exchange_type,
exchange,
routing_key,
)
filter_args = {}
if exchange_type == 'direct':
# Direct exchange is implemented as a single subscription
# E.g. for exchange 'test_direct':
# -topic:'test_direct'
# -bound queue:'direct1':
# -subscription: direct1' on topic 'test_direct'
# -filter:routing_key'
filter_args = {
'filter': f'attributes.routing_key="{routing_key}"'
}
subscription_path = self.subscriber.subscription_path(
self.project_id, queue
)
message_retention_duration = self.expiration_seconds
elif exchange_type == 'fanout':
# Fanout exchange is implemented as a separate subscription.
# E.g. for exchange 'test_fanout':
# -topic:'test_fanout'
# -bound queue 'fanout1':
# -subscription:'fanout1-uuid' on topic 'test_fanout'
# -bound queue 'fanout2':
# -subscription:'fanout2-uuid' on topic 'test_fanout'
uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
uniq_sub_name = f'{queue}-{uid}'
subscription_path = self.subscriber.subscription_path(
self.project_id, uniq_sub_name
)
self._tmp_subscriptions.add(subscription_path)
self._fanout_exchanges.add(exchange)
message_retention_duration = 600
else:
raise NotImplementedError(
f'exchange type {exchange_type} not implemented'
)
exchange_topic = self._create_topic(
self.project_id, exchange, message_retention_duration
)
self._create_subscription(
topic_path=exchange_topic,
subscription_path=subscription_path,
filter_args=filter_args,
msg_retention=message_retention_duration,
)
qdesc = QueueDescriptor(
name=queue,
topic_path=exchange_topic,
subscription_id=queue,
subscription_path=subscription_path,
)
self._queue_cache[queue] = qdesc
def _create_topic(
self,
project_id: str,
topic_id: str,
message_retention_duration: int = None,
) -> str:
topic_path = self.publisher.topic_path(project_id, topic_id)
if self._is_topic_exists(topic_path):
# topic creation takes a while, so skip if possible
logger.debug('topic: %s exists', topic_path)
return topic_path
try:
logger.debug('creating topic: %s', topic_path)
request = {'name': topic_path}
if message_retention_duration:
request[
'message_retention_duration'
] = f'{message_retention_duration}s'
self.publisher.create_topic(request=request)
except AlreadyExists:
pass
return topic_path
def _is_topic_exists(self, topic_path: str) -> bool:
topics = self.publisher.list_topics(
request={"project": f'projects/{self.project_id}'}
)
for t in topics:
if t.name == topic_path:
return True
return False
def _create_subscription(
self,
project_id: str = None,
topic_id: str = None,
topic_path: str = None,
subscription_path: str = None,
filter_args=None,
msg_retention: int = None,
) -> str:
subscription_path = (
subscription_path
or self.subscriber.subscription_path(self.project_id, topic_id)
)
topic_path = topic_path or self.publisher.topic_path(
project_id, topic_id
)
try:
logger.debug(
'creating subscription: %s, topic: %s, filter: %s',
subscription_path,
topic_path,
filter_args,
)
msg_retention = msg_retention or self.expiration_seconds
self.subscriber.create_subscription(
request={
"name": subscription_path,
"topic": topic_path,
'ack_deadline_seconds': self.ack_deadline_seconds,
'expiration_policy': {
'ttl': f'{self.expiration_seconds}s'
},
'message_retention_duration': f'{msg_retention}s',
**(filter_args or {}),
}
)
except AlreadyExists:
pass
return subscription_path
def _delete(self, queue, *args, **kwargs):
"""Delete a queue by name."""
queue = self.entity_name(queue)
logger.info('deleting queue: %s', queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
self.subscriber.delete_subscription(
request={"subscription": qdesc.subscription_path}
)
self._queue_cache.pop(queue, None)
def _put(self, queue, message, **kwargs):
"""Put a message onto the queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
routing_key = self._get_routing_key(message)
logger.debug(
'putting message to queue: %s, topic: %s, routing_key: %s',
queue,
qdesc.topic_path,
routing_key,
)
encoded_message = dumps(message)
self.publisher.publish(
qdesc.topic_path,
encoded_message.encode("utf-8"),
routing_key=routing_key,
)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
"""Put a message onto fanout exchange."""
self._lookup(exchange, routing_key)
topic_path = self.publisher.topic_path(self.project_id, exchange)
logger.debug(
'putting msg to fanout exchange: %s, topic: %s',
exchange,
topic_path,
)
encoded_message = dumps(message)
self.publisher.publish(
topic_path,
encoded_message.encode("utf-8"),
retry=Retry(deadline=self.retry_timeout_seconds),
)
def _get(self, queue: str, timeout: float = None):
"""Retrieves a single message from a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': 1,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
if len(response.received_messages) == 0:
raise Empty()
message = response.received_messages[0]
ack_id = message.ack_id
payload = loads(message.message.data)
delivery_info = payload['properties']['delivery_info']
logger.debug(
'queue:%s got message, ack_id: %s, payload: %s',
queue,
ack_id,
payload['properties'],
)
if self._is_auto_ack(payload['properties']):
logger.debug('auto acking message ack_id: %s', ack_id)
self._do_ack([ack_id], qdesc.subscription_path)
else:
delivery_info['gcpubsub_message'] = {
'queue': queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
qdesc.unacked_ids.append(ack_id)
return payload
def _is_auto_ack(self, payload_properties: dict):
exchange = payload_properties['delivery_info']['exchange']
delivery_mode = payload_properties['delivery_mode']
return (
delivery_mode == TRANSIENT_DELIVERY_MODE
or exchange in self._fanout_exchanges
)
def _get_bulk(self, queue: str, timeout: float):
"""Retrieves bulk of messages from a queue."""
prefixed_queue = self.entity_name(queue)
qdesc = self._queue_cache[prefixed_queue]
max_messages = self._get_max_messages_estimate()
if not max_messages:
raise Empty()
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': max_messages,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
received_messages = response.received_messages
if len(received_messages) == 0:
raise Empty()
auto_ack_ids = []
ret_payloads = []
logger.debug(
'batching %d messages from queue: %s',
len(received_messages),
prefixed_queue,
)
for message in received_messages:
ack_id = message.ack_id
payload = loads(bytes_to_str(message.message.data))
delivery_info = payload['properties']['delivery_info']
delivery_info['gcpubsub_message'] = {
'queue': prefixed_queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
if self._is_auto_ack(payload['properties']):
auto_ack_ids.append(ack_id)
else:
qdesc.unacked_ids.append(ack_id)
ret_payloads.append(payload)
if auto_ack_ids:
logger.debug('auto acking ack_ids: %s', auto_ack_ids)
self._do_ack(auto_ack_ids, qdesc.subscription_path)
return queue, ret_payloads
def _get_max_messages_estimate(self) -> int:
max_allowed = self.qos.can_consume_max_estimate()
max_if_unlimited = self.bulk_max_messages
return max_if_unlimited if max_allowed is None else max_allowed
def _lookup(self, exchange, routing_key, default=None):
exchange_info = self.state.exchanges.get(exchange, {})
if not exchange_info:
return super()._lookup(exchange, routing_key, default)
ret = self.typeof(exchange).lookup(
self.get_table(exchange),
exchange,
routing_key,
default,
)
if ret:
return ret
logger.debug(
'no queues bound to exchange: %s, binding on the fly',
exchange,
)
self.queue_bind(exchange, exchange, routing_key)
return [exchange]
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue.
This is a *rough* estimation, as Pub/Sub doesn't provide
an exact API.
"""
queue = self.entity_name(queue)
if queue not in self._queue_cache:
return 0
qdesc = self._queue_cache[queue]
result = query.Query(
self.monitor,
self.project_id,
'pubsub.googleapis.com/subscription/num_undelivered_messages',
end_time=datetime.datetime.now(),
minutes=1,
).select_resources(subscription_id=qdesc.subscription_id)
# monitoring API requires the caller to have the monitoring.viewer
# role. Since we can live without the exact number of messages
# in the queue, we can ignore the exception and allow users to
# use the transport without this role.
with suppress(PermissionDenied):
return sum(
content.points[0].value.int64_value for content in result
)
return -1
def basic_ack(self, delivery_tag, multiple=False):
"""Acknowledge one message."""
if multiple:
raise NotImplementedError('multiple acks not implemented')
delivery_info = self.qos.get(delivery_tag).delivery_info
pubsub_message = delivery_info['gcpubsub_message']
ack_id = pubsub_message['ack_id']
queue = pubsub_message['queue']
logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
subscription_path = pubsub_message['subscription_path']
self._do_ack([ack_id], subscription_path)
qdesc = self._queue_cache[queue]
qdesc.unacked_ids.remove(ack_id)
super().basic_ack(delivery_tag)
def _do_ack(self, ack_ids: list[str], subscription_path: str):
self.subscriber.acknowledge(
request={"subscription": subscription_path, "ack_ids": ack_ids},
retry=Retry(deadline=self.retry_timeout_seconds),
)
def _purge(self, queue: str):
"""Delete all current messages in a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
n = self._size(queue)
self.subscriber.seek(
request={
"subscription": qdesc.subscription_path,
"time": datetime.datetime.now(),
}
)
return n
def _extend_unacked_deadline(self):
thread_id = threading.get_native_id()
logger.info(
'unacked deadline extension thread: [%s] started',
thread_id,
)
min_deadline_sleep = self._min_ack_deadline / 2
sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
while not self._stop_extender.wait(sleep_time):
for qdesc in self._queue_cache.values():
if len(qdesc.unacked_ids) == 0:
logger.debug(
'thread [%s]: no unacked messages for %s',
thread_id,
qdesc.subscription_path,
)
continue
logger.debug(
'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
thread_id,
qdesc.subscription_path,
len(qdesc.unacked_ids),
list(qdesc.unacked_ids),
)
self.subscriber.modify_ack_deadline(
request={
"subscription": qdesc.subscription_path,
"ack_ids": list(qdesc.unacked_ids),
"ack_deadline_seconds": self.ack_deadline_seconds,
}
)
logger.info(
'unacked deadline extension thread [%s] stopped', thread_id
)
def after_reply_message_received(self, queue: str):
queue = self.entity_name(queue)
sub = self.subscriber.subscription_path(self.project_id, queue)
logger.debug(
'after_reply_message_received: queue: %s, sub: %s', queue, sub
)
self._tmp_subscriptions.add(sub)
@cached_property
def subscriber(self):
return SubscriberClient()
@cached_property
def publisher(self):
return PublisherClient()
@cached_property
def monitor(self):
return monitoring_v3.MetricServiceClient()
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def wait_time_seconds(self):
return self.transport_options.get(
'wait_time_seconds', self.default_wait_time_seconds
)
@cached_property
def retry_timeout_seconds(self):
return self.transport_options.get(
'retry_timeout_seconds', self.default_retry_timeout_seconds
)
@cached_property
def ack_deadline_seconds(self):
return self.transport_options.get(
'ack_deadline_seconds', self.default_ack_deadline_seconds
)
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', 'kombu-')
@cached_property
def expiration_seconds(self):
return self.transport_options.get(
'expiration_seconds', self.default_expiration_seconds
)
@cached_property
def bulk_max_messages(self):
return self.transport_options.get(
'bulk_max_messages', self.default_bulk_max_messages
)
def close(self):
"""Close the channel."""
logger.debug('closing channel')
while self._tmp_subscriptions:
sub = self._tmp_subscriptions.pop()
with suppress(Exception):
logger.debug('deleting subscription: %s', sub)
self.subscriber.delete_subscription(
request={"subscription": sub}
)
if not self._n_channels.dec():
self._stop_extender.set()
Channel._unacked_extender.join()
super().close()
@staticmethod
def _get_routing_key(message):
routing_key = (
message['properties']
.get('delivery_info', {})
.get('routing_key', '')
)
return routing_key
class Transport(virtual.Transport):
"""GCP Pub/Sub transport."""
Channel = Channel
can_parse_url = True
polling_interval = 0.1
connection_errors = virtual.Transport.connection_errors + (
pubsub_exceptions.TimeoutError,
)
channel_errors = (
virtual.Transport.channel_errors
+ (
publisher_exceptions.FlowControlLimitError,
publisher_exceptions.MessageTooLargeError,
publisher_exceptions.PublishError,
publisher_exceptions.TimeoutError,
publisher_exceptions.PublishToPausedOrderingKeyException,
)
+ (subscriber_exceptions.AcknowledgeError,)
)
driver_type = 'gcpubsub'
driver_name = 'pubsub_v1'
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct', 'fanout']),
)
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self._pool = ThreadPoolExecutor()
self._get_bulk_future_to_queue: dict[Future, str] = dict()
def driver_version(self):
return package_version.__version__
@staticmethod
def parse_uri(uri: str) -> str:
# URL like:
# gcpubsub://projects/project-name
project = uri.split('gcpubsub://projects/')[1]
return project.strip('/')
@classmethod
def as_uri(self, uri: str, include_password=False, mask='**') -> str:
return uri or 'gcpubsub://'
def drain_events(self, connection, timeout=None):
time_start = monotonic()
polling_interval = self.polling_interval
if timeout and polling_interval and polling_interval > timeout:
polling_interval = timeout
while 1:
try:
self._drain_from_active_queues(timeout=timeout)
except Empty:
if timeout and monotonic() - time_start >= timeout:
raise socket_timeout()
if polling_interval:
sleep(polling_interval)
else:
break
def _drain_from_active_queues(self, timeout):
# cleanup empty requests from prev run
self._rm_empty_bulk_requests()
# submit new requests for all active queues
# longer timeout means less frequent polling
# and more messages in a single bulk
self._submit_get_bulk_requests(timeout=10)
done, _ = wait(
self._get_bulk_future_to_queue,
timeout=timeout,
return_when=FIRST_COMPLETED,
)
empty = {f for f in done if f.exception()}
done -= empty
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
if not done:
raise Empty()
logger.debug('got %d done get_bulk tasks', len(done))
for f in done:
queue, payloads = f.result()
for payload in payloads:
logger.debug('consuming message from queue: %s', queue)
if queue not in self._callbacks:
logger.warning(
'Message for queue %s without consumers', queue
)
continue
self._deliver(payload, queue)
self._get_bulk_future_to_queue.pop(f, None)
def _rm_empty_bulk_requests(self):
empty = {
f
for f in self._get_bulk_future_to_queue
if f.done() and f.exception()
}
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
def _submit_get_bulk_requests(self, timeout):
queues_with_submitted_get_bulk = set(
self._get_bulk_future_to_queue.values()
)
for channel in self.channels:
for queue in channel._active_queues:
if queue in queues_with_submitted_get_bulk:
continue
future = self._pool.submit(channel._get_bulk, queue, timeout)
self._get_bulk_future_to_queue[future] = queue

View File

@ -0,0 +1,3 @@
google-cloud-pubsub>=2.18.4
google-cloud-monitoring>=2.16.0
grpcio==1.66.2

View File

@ -16,3 +16,4 @@ urllib3>=1.26.16; sys_platform != 'win32'
-r extras/zstd.txt
-r extras/sqlalchemy.txt
-r extras/etcd.txt
-r extras/gcpubsub.txt

View File

@ -103,6 +103,7 @@ setup(
'redis': extras('redis.txt'),
'mongodb': extras('mongodb.txt'),
'sqs': extras('sqs.txt'),
'gcpubsub': extras('gcpubsub.txt'),
'zookeeper': extras('zookeeper.txt'),
'sqlalchemy': extras('sqlalchemy.txt'),
'librabbitmq': extras('librabbitmq.txt'),

View File

@ -0,0 +1,793 @@
from __future__ import annotations
from concurrent.futures import Future
from datetime import datetime
from queue import Empty
from unittest.mock import MagicMock, call, patch
import pytest
from _socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
PermissionDenied)
from kombu.transport.gcpubsub import (AtomicCounter, Channel, QueueDescriptor,
Transport, UnackedIds)
class test_UnackedIds:
def setup_method(self):
self.unacked_ids = UnackedIds()
def test_append(self):
self.unacked_ids.append('test_id')
assert self.unacked_ids[0] == 'test_id'
def test_extend(self):
self.unacked_ids.extend(['test_id1', 'test_id2'])
assert self.unacked_ids[0] == 'test_id1'
assert self.unacked_ids[1] == 'test_id2'
def test_pop(self):
self.unacked_ids.append('test_id')
popped_id = self.unacked_ids.pop()
assert popped_id == 'test_id'
assert len(self.unacked_ids) == 0
def test_remove(self):
self.unacked_ids.append('test_id')
self.unacked_ids.remove('test_id')
assert len(self.unacked_ids) == 0
def test_len(self):
self.unacked_ids.append('test_id')
assert len(self.unacked_ids) == 1
def test_getitem(self):
self.unacked_ids.append('test_id')
assert self.unacked_ids[0] == 'test_id'
class test_AtomicCounter:
def setup_method(self):
self.counter = AtomicCounter()
def test_inc(self):
assert self.counter.inc() == 1
assert self.counter.inc(5) == 6
def test_dec(self):
self.counter.inc(5)
assert self.counter.dec() == 4
assert self.counter.dec(2) == 2
def test_get(self):
self.counter.inc(7)
assert self.counter.get() == 7
@pytest.fixture
def channel():
with patch.object(Channel, '__init__', lambda self: None):
channel = Channel()
channel.connection = MagicMock()
channel.queue_name_prefix = "kombu-"
channel.project_id = "test_project"
channel._queue_cache = {}
channel._n_channels = MagicMock()
channel._stop_extender = MagicMock()
channel.subscriber = MagicMock()
channel.publisher = MagicMock()
channel.closed = False
with patch.object(
Channel, 'conninfo', new_callable=MagicMock
), patch.object(
Channel, 'transport_options', new_callable=MagicMock
), patch.object(
Channel, 'qos', new_callable=MagicMock
):
yield channel
class test_Channel:
@patch('kombu.transport.gcpubsub.ThreadPoolExecutor')
@patch('kombu.transport.gcpubsub.threading.Event')
@patch('kombu.transport.gcpubsub.threading.Thread')
@patch(
'kombu.transport.gcpubsub.Channel._get_free_channel_id',
return_value=1,
)
@patch(
'kombu.transport.gcpubsub.Channel._n_channels.inc',
return_value=1,
)
def test_channel_init(
self,
n_channels_in_mock,
channel_id_mock,
mock_thread,
mock_event,
mock_executor,
):
mock_connection = MagicMock()
ch = Channel(mock_connection)
ch._n_channels.inc.assert_called_once()
mock_thread.assert_called_once_with(
target=ch._extend_unacked_deadline,
daemon=True,
)
mock_thread.return_value.start.assert_called_once()
def test_entity_name(self, channel):
name = "test_queue"
result = channel.entity_name(name)
assert result == "kombu-test_queue"
@patch('kombu.transport.gcpubsub.uuid3', return_value='uuid')
@patch('kombu.transport.gcpubsub.gethostname', return_value='hostname')
@patch('kombu.transport.gcpubsub.getpid', return_value=1234)
def test_queue_bind_direct(
self, mock_pid, mock_hostname, mock_uuid, channel
):
exchange = 'direct'
routing_key = 'test_routing_key'
pattern = 'test_pattern'
queue = 'test_queue'
subscription_path = 'projects/project-id/subscriptions/test_queue'
channel.subscriber.subscription_path = MagicMock(
return_value=subscription_path
)
channel._create_topic = MagicMock(return_value='topic_path')
channel._create_subscription = MagicMock()
# Mock the state and exchange type
mock_connection = MagicMock(name='mock_connection')
channel.connection = mock_connection
channel.state.exchanges = {exchange: {'type': 'direct'}}
mock_exchange = MagicMock(name='mock_exchange', type='direct')
channel.exchange_types = {'direct': mock_exchange}
channel._queue_bind(exchange, routing_key, pattern, queue)
channel._create_topic.assert_called_once_with(
channel.project_id, exchange, channel.expiration_seconds
)
channel._create_subscription.assert_called_once_with(
topic_path='topic_path',
subscription_path=subscription_path,
filter_args={'filter': f'attributes.routing_key="{routing_key}"'},
msg_retention=channel.expiration_seconds,
)
assert channel.entity_name(queue) in channel._queue_cache
@patch('kombu.transport.gcpubsub.uuid3', return_value='uuid')
@patch('kombu.transport.gcpubsub.gethostname', return_value='hostname')
@patch('kombu.transport.gcpubsub.getpid', return_value=1234)
def test_queue_bind_fanout(
self, mock_pid, mock_hostname, mock_uuid, channel
):
exchange = 'test_exchange'
routing_key = 'test_routing_key'
pattern = 'test_pattern'
queue = 'test_queue'
uniq_sub_name = 'test_queue-uuid'
subscription_path = (
f'projects/project-id/subscriptions/{uniq_sub_name}'
)
channel.subscriber.subscription_path = MagicMock(
return_value=subscription_path
)
channel._create_topic = MagicMock(return_value='topic_path')
channel._create_subscription = MagicMock()
# Mock the state and exchange type
mock_connection = MagicMock(name='mock_connection')
channel.connection = mock_connection
channel.state.exchanges = {exchange: {'type': 'fanout'}}
mock_exchange = MagicMock(name='mock_exchange', type='fanout')
channel.exchange_types = {'fanout': mock_exchange}
channel._queue_bind(exchange, routing_key, pattern, queue)
channel._create_topic.assert_called_once_with(
channel.project_id, exchange, 600
)
channel._create_subscription.assert_called_once_with(
topic_path='topic_path',
subscription_path=subscription_path,
filter_args={},
msg_retention=600,
)
assert channel.entity_name(queue) in channel._queue_cache
assert subscription_path in channel._tmp_subscriptions
assert exchange in channel._fanout_exchanges
def test_queue_bind_not_implemented(self, channel):
exchange = 'test_exchange'
routing_key = 'test_routing_key'
pattern = 'test_pattern'
queue = 'test_queue'
channel.typeof = MagicMock(return_value=MagicMock(type='unsupported'))
with pytest.raises(NotImplementedError):
channel._queue_bind(exchange, routing_key, pattern, queue)
def test_create_topic(self, channel):
channel.project_id = "project_id"
topic_id = "topic_id"
channel._is_topic_exists = MagicMock(return_value=False)
channel.publisher.topic_path = MagicMock(return_value="topic_path")
channel.publisher.create_topic = MagicMock()
result = channel._create_topic(channel.project_id, topic_id)
assert result == "topic_path"
channel.publisher.create_topic.assert_called_once()
channel._create_topic(
channel.project_id, topic_id, message_retention_duration=10
)
assert (
dict(
request={
'name': 'topic_path',
'message_retention_duration': '10s',
}
)
in channel.publisher.create_topic.call_args
)
channel.publisher.create_topic.side_effect = AlreadyExists(
"test_error"
)
channel._create_topic(
channel.project_id, topic_id, message_retention_duration=10
)
def test_is_topic_exists(self, channel):
topic_path = "projects/project-id/topics/test_topic"
mock_topic = MagicMock()
mock_topic.name = topic_path
channel.publisher.list_topics.return_value = [mock_topic]
result = channel._is_topic_exists(topic_path)
assert result is True
channel.publisher.list_topics.assert_called_once_with(
request={"project": f'projects/{channel.project_id}'}
)
def test_is_topic_not_exists(self, channel):
topic_path = "projects/project-id/topics/test_topic"
channel.publisher.list_topics.return_value = []
result = channel._is_topic_exists(topic_path)
assert result is False
channel.publisher.list_topics.assert_called_once_with(
request={"project": f'projects/{channel.project_id}'}
)
def test_create_subscription(self, channel):
channel.project_id = "project_id"
topic_id = "topic_id"
subscription_path = "subscription_path"
topic_path = "topic_path"
channel.subscriber.subscription_path = MagicMock(
return_value=subscription_path
)
channel.publisher.topic_path = MagicMock(return_value=topic_path)
channel.subscriber.create_subscription = MagicMock()
result = channel._create_subscription(
project_id=channel.project_id,
topic_id=topic_id,
subscription_path=subscription_path,
topic_path=topic_path,
)
assert result == subscription_path
channel.subscriber.create_subscription.assert_called_once()
def test_delete(self, channel):
queue = "test_queue"
subscription_path = "projects/project-id/subscriptions/test_queue"
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id=queue,
subscription_path=subscription_path,
)
channel.subscriber = MagicMock()
channel._queue_cache[channel.entity_name(queue)] = qdesc
channel._delete(queue)
channel.subscriber.delete_subscription.assert_called_once_with(
request={"subscription": subscription_path}
)
assert queue not in channel._queue_cache
def test_put(self, channel):
queue = "test_queue"
message = {
"properties": {"delivery_info": {"routing_key": "test_key"}}
}
channel.entity_name = MagicMock(return_value=queue)
channel._queue_cache[channel.entity_name(queue)] = QueueDescriptor(
name=queue,
topic_path="topic_path",
subscription_id=queue,
subscription_path="subscription_path",
)
channel._get_routing_key = MagicMock(return_value="test_key")
channel.publisher.publish = MagicMock()
channel._put(queue, message)
channel.publisher.publish.assert_called_once()
def test_put_fanout(self, channel):
exchange = "test_exchange"
message = {
"properties": {"delivery_info": {"routing_key": "test_key"}}
}
routing_key = "test_key"
channel._lookup = MagicMock()
channel.publisher.topic_path = MagicMock(return_value="topic_path")
channel.publisher.publish = MagicMock()
channel._put_fanout(exchange, message, routing_key)
channel._lookup.assert_called_once_with(exchange, routing_key)
channel.publisher.topic_path.assert_called_once_with(
channel.project_id, exchange
)
assert 'topic_path', (
b'{"properties": {"delivery_info": {"routing_key": "test_key"}}}'
in channel.publisher.publish.call_args
)
def test_get(self, channel):
queue = "test_queue"
channel.entity_name = MagicMock(return_value=queue)
channel._queue_cache[queue] = QueueDescriptor(
name=queue,
topic_path="topic_path",
subscription_id=queue,
subscription_path="subscription_path",
)
channel.subscriber.pull = MagicMock(
return_value=MagicMock(
received_messages=[
MagicMock(
ack_id="ack_id",
message=MagicMock(
data=b'{"properties": '
b'{"delivery_info": '
b'{"exchange": "exchange"},"delivery_mode": 1}}'
),
)
]
)
)
channel.subscriber.acknowledge = MagicMock()
payload = channel._get(queue)
assert (
payload["properties"]["delivery_info"]["exchange"] == "exchange"
)
channel.subscriber.pull.side_effect = DeadlineExceeded("test_error")
with pytest.raises(Empty):
channel._get(queue)
def test_get_bulk(self, channel):
queue = "test_queue"
subscription_path = "projects/project-id/subscriptions/test_queue"
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id=queue,
subscription_path=subscription_path,
)
channel._queue_cache[channel.entity_name(queue)] = qdesc
data = b'{"properties": {"delivery_info": {"exchange": "exchange"}}}'
received_message = MagicMock(
ack_id="ack_id",
message=MagicMock(data=data),
)
channel.subscriber.pull = MagicMock(
return_value=MagicMock(received_messages=[received_message])
)
channel.bulk_max_messages = 10
channel._is_auto_ack = MagicMock(return_value=True)
channel._do_ack = MagicMock()
channel.qos.can_consume_max_estimate = MagicMock(return_value=None)
queue, payloads = channel._get_bulk(queue, timeout=10)
assert len(payloads) == 1
assert (
payloads[0]["properties"]["delivery_info"]["exchange"]
== "exchange"
)
channel._do_ack.assert_called_once_with(["ack_id"], subscription_path)
channel.subscriber.pull.side_effect = DeadlineExceeded("test_error")
with pytest.raises(Empty):
channel._get_bulk(queue, timeout=10)
def test_lookup(self, channel):
exchange = "test_exchange"
routing_key = "test_key"
default = None
channel.connection = MagicMock()
channel.state.exchanges = {exchange: {"type": "direct"}}
channel.typeof = MagicMock(
return_value=MagicMock(lookup=MagicMock(return_value=["queue1"]))
)
channel.get_table = MagicMock(return_value="table")
result = channel._lookup(exchange, routing_key, default)
channel.typeof.return_value.lookup.assert_called_once_with(
"table", exchange, routing_key, default
)
assert result == ["queue1"]
# Test the case where no queues are bound to the exchange
channel.typeof.return_value.lookup.return_value = None
channel.queue_bind = MagicMock()
result = channel._lookup(exchange, routing_key, default)
channel.queue_bind.assert_called_once_with(
exchange, exchange, routing_key
)
assert result == [exchange]
@patch('kombu.transport.gcpubsub.monitoring_v3')
@patch('kombu.transport.gcpubsub.query.Query')
def test_size(self, mock_query, mock_monitor, channel):
queue = "test_queue"
subscription_id = "test_subscription"
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id=subscription_id,
subscription_path="projects/project-id/subscriptions/test_subscription", # E501
)
channel._queue_cache[channel.entity_name(queue)] = qdesc
mock_query_result = MagicMock()
mock_query_result.select_resources.return_value = [
MagicMock(points=[MagicMock(value=MagicMock(int64_value=5))])
]
mock_query.return_value = mock_query_result
size = channel._size(queue)
assert size == 5
# Test the case where the queue is not in the cache
size = channel._size("non_existent_queue")
assert size == 0
# Test the case where the query raises PermissionDenied
mock_item = MagicMock()
mock_item.points.__getitem__.side_effect = PermissionDenied(
'test_error'
)
mock_query_result.select_resources.return_value = [mock_item]
size = channel._size(queue)
assert size == -1
def test_basic_ack(self, channel):
delivery_tag = "test_delivery_tag"
ack_id = "test_ack_id"
queue = "test_queue"
subscription_path = (
"projects/project-id/subscriptions/test_subscription"
)
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id="test_subscription",
subscription_path=subscription_path,
)
channel._queue_cache[queue] = qdesc
delivery_info = {
'gcpubsub_message': {
'queue': queue,
'ack_id': ack_id,
'subscription_path': subscription_path,
}
}
channel.qos.get = MagicMock(
return_value=MagicMock(delivery_info=delivery_info)
)
channel._do_ack = MagicMock()
channel.basic_ack(delivery_tag)
channel._do_ack.assert_called_once_with([ack_id], subscription_path)
assert ack_id not in qdesc.unacked_ids
def test_do_ack(self, channel):
ack_ids = ["ack_id1", "ack_id2"]
subscription_path = (
"projects/project-id/subscriptions/test_subscription"
)
channel.subscriber = MagicMock()
channel._do_ack(ack_ids, subscription_path)
assert subscription_path, (
ack_ids in channel.subscriber.acknowledge.call_args
)
def test_purge(self, channel):
queue = "test_queue"
subscription_path = f"projects/project-id/subscriptions/{queue}"
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id="test_subscription",
subscription_path=subscription_path,
)
channel._queue_cache[channel.entity_name(queue)] = qdesc
channel.subscriber = MagicMock()
with patch.object(channel, '_size', return_value=10), patch(
'kombu.transport.gcpubsub.datetime.datetime'
) as dt_mock:
dt_mock.now.return_value = datetime(2021, 1, 1)
result = channel._purge(queue)
assert result == 10
channel.subscriber.seek.assert_called_once_with(
request={
"subscription": subscription_path,
"time": datetime(2021, 1, 1),
}
)
# Test the case where the queue is not in the cache
result = channel._purge("non_existent_queue")
assert result is None
def test_extend_unacked_deadline(self, channel):
queue = "test_queue"
subscription_path = (
"projects/project-id/subscriptions/test_subscription"
)
ack_ids = ["ack_id1", "ack_id2"]
qdesc = QueueDescriptor(
name=queue,
topic_path="projects/project-id/topics/test_topic",
subscription_id="test_subscription",
subscription_path=subscription_path,
)
channel.transport_options = {"ack_deadline_seconds": 240}
channel._queue_cache[channel.entity_name(queue)] = qdesc
qdesc.unacked_ids.extend(ack_ids)
channel._stop_extender.wait = MagicMock(side_effect=[False, True])
channel.subscriber.modify_ack_deadline = MagicMock()
channel._extend_unacked_deadline()
channel.subscriber.modify_ack_deadline.assert_called_once_with(
request={
"subscription": subscription_path,
"ack_ids": ack_ids,
"ack_deadline_seconds": 240,
}
)
for _ in ack_ids:
qdesc.unacked_ids.pop()
channel._stop_extender.wait = MagicMock(side_effect=[False, True])
modify_ack_deadline_calls = (
channel.subscriber.modify_ack_deadline.call_count
)
channel._extend_unacked_deadline()
assert (
channel.subscriber.modify_ack_deadline.call_count
== modify_ack_deadline_calls
)
def test_after_reply_message_received(self, channel):
queue = 'test-queue'
subscription_path = f'projects/test-project/subscriptions/{queue}'
channel.subscriber.subscription_path.return_value = subscription_path
channel.after_reply_message_received(queue)
# Check that the subscription path is added to _tmp_subscriptions
assert subscription_path in channel._tmp_subscriptions
def test_subscriber(self, channel):
assert channel.subscriber
def test_publisher(self, channel):
assert channel.publisher
def test_transport_options(self, channel):
assert channel.transport_options
def test_bulk_max_messages_default(self, channel):
assert channel.bulk_max_messages == channel.transport_options.get(
'bulk_max_messages'
)
def test_close(self, channel):
channel._tmp_subscriptions = {'sub1', 'sub2'}
channel._n_channels.dec.return_value = 0
with patch.object(
Channel._unacked_extender, 'join'
) as mock_join, patch(
'kombu.transport.virtual.Channel.close'
) as mock_super_close:
channel.close()
channel.subscriber.delete_subscription.assert_has_calls(
[
call(request={"subscription": 'sub1'}),
call(request={"subscription": 'sub2'}),
],
any_order=True,
)
channel._stop_extender.set.assert_called_once()
mock_join.assert_called_once()
mock_super_close.assert_called_once()
@pytest.fixture
def transport():
return Transport(client=MagicMock())
class test_Transport:
def test_driver_version(self, transport):
assert transport.driver_version()
def test_as_uri(self, transport):
result = transport.as_uri('gcpubsub://')
assert result == 'gcpubsub://'
def test_drain_events_timeout(self, transport):
transport.polling_interval = 4
with patch.object(
transport, '_drain_from_active_queues', side_effect=Empty
), patch(
'kombu.transport.gcpubsub.monotonic',
side_effect=[0, 1, 2, 3, 4, 5],
), patch(
'kombu.transport.gcpubsub.sleep'
) as mock_sleep:
with pytest.raises(socket_timeout):
transport.drain_events(None, timeout=3)
mock_sleep.assert_called()
def test_drain_events_no_timeout(self, transport):
with patch.object(
transport, '_drain_from_active_queues', side_effect=[Empty, None]
), patch(
'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1]
), patch(
'kombu.transport.gcpubsub.sleep'
) as mock_sleep:
transport.drain_events(None, timeout=None)
mock_sleep.assert_called()
def test_drain_events_polling_interval(self, transport):
transport.polling_interval = 2
with patch.object(
transport, '_drain_from_active_queues', side_effect=[Empty, None]
), patch(
'kombu.transport.gcpubsub.monotonic', side_effect=[0, 1, 2]
), patch(
'kombu.transport.gcpubsub.sleep'
) as mock_sleep:
transport.drain_events(None, timeout=5)
mock_sleep.assert_called_with(2)
def test_drain_from_active_queues_empty(self, transport):
with patch.object(
transport, '_rm_empty_bulk_requests'
) as mock_rm_empty, patch.object(
transport, '_submit_get_bulk_requests'
) as mock_submit, patch(
'kombu.transport.gcpubsub.wait', return_value=(set(), set())
) as mock_wait:
with pytest.raises(Empty):
transport._drain_from_active_queues(timeout=10)
mock_rm_empty.assert_called_once()
mock_submit.assert_called_once_with(timeout=10)
mock_wait.assert_called_once()
def test_drain_from_active_queues_done(self, transport):
future = Future()
future.set_result(('queue', [{'properties': {'delivery_info': {}}}]))
with patch.object(
transport, '_rm_empty_bulk_requests'
) as mock_rm_empty, patch.object(
transport, '_submit_get_bulk_requests'
) as mock_submit, patch(
'kombu.transport.gcpubsub.wait', return_value=({future}, set())
) as mock_wait, patch.object(
transport, '_deliver'
) as mock_deliver:
transport._callbacks = {'queue'}
transport._drain_from_active_queues(timeout=10)
mock_rm_empty.assert_called_once()
mock_submit.assert_called_once_with(timeout=10)
mock_wait.assert_called_once()
mock_deliver.assert_called_once_with(
{'properties': {'delivery_info': {}}}, 'queue'
)
mock_deliver_call_count = mock_deliver.call_count
transport._callbacks = {}
transport._drain_from_active_queues(timeout=10)
assert mock_deliver_call_count == mock_deliver.call_count
def test_drain_from_active_queues_exception(self, transport):
future = Future()
future.set_exception(Exception("Test exception"))
with patch.object(
transport, '_rm_empty_bulk_requests'
) as mock_rm_empty, patch.object(
transport, '_submit_get_bulk_requests'
) as mock_submit, patch(
'kombu.transport.gcpubsub.wait', return_value=({future}, set())
) as mock_wait:
with pytest.raises(Empty):
transport._drain_from_active_queues(timeout=10)
mock_rm_empty.assert_called_once()
mock_submit.assert_called_once_with(timeout=10)
mock_wait.assert_called_once()
def test_rm_empty_bulk_requests(self, transport):
# Create futures with exceptions to simulate empty requests
future_with_exception = Future()
future_with_exception.set_exception(Exception("Test exception"))
transport._get_bulk_future_to_queue = {
future_with_exception: 'queue1',
}
transport._rm_empty_bulk_requests()
# Assert that the future with exception is removed
assert (
future_with_exception not in transport._get_bulk_future_to_queue
)
def test_submit_get_bulk_requests(self, transport):
channel_mock = MagicMock(spec=Channel)
channel_mock._active_queues = ['queue1', 'queue2']
transport.channels = [channel_mock]
with patch.object(
transport._pool, 'submit', return_value=MagicMock()
) as mock_submit:
transport._submit_get_bulk_requests(timeout=10)
# Check that submit was called twice, once for each queue
assert mock_submit.call_count == 2
mock_submit.assert_any_call(channel_mock._get_bulk, 'queue1', 10)
mock_submit.assert_any_call(channel_mock._get_bulk, 'queue2', 10)
def test_submit_get_bulk_requests_with_existing_futures(self, transport):
channel_mock = MagicMock(spec=Channel)
channel_mock._active_queues = ['queue1', 'queue2']
transport.channels = [channel_mock]
# Simulate existing futures
future_mock = MagicMock()
transport._get_bulk_future_to_queue = {future_mock: 'queue1'}
with patch.object(
transport._pool, 'submit', return_value=MagicMock()
) as mock_submit:
transport._submit_get_bulk_requests(timeout=10)
# Check that submit was called only once for the new queue
assert mock_submit.call_count == 1
mock_submit.assert_called_with(
channel_mock._get_bulk, 'queue2', 10
)

View File

@ -25,7 +25,9 @@ python =
[testenv]
sitepackages = False
setenv = C_DEBUG_TEST = 1
setenv =
C_DEBUG_TEST = 1
PIP_EXTRA_INDEX_URL=https://celery.github.io/celery-wheelhouse/repo/simple/
passenv =
DISTUTILS_USE_SDK
deps=