mirror of https://github.com/celery/kombu.git
azure service bus: add managed identity support (#1641)
This commit is contained in:
parent
55370c7e4c
commit
6f8676630d
|
@ -30,11 +30,13 @@ Features
|
|||
Connection String
|
||||
=================
|
||||
|
||||
Connection string has the following format:
|
||||
Connection string has the following formats:
|
||||
|
||||
.. code-block::
|
||||
|
||||
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
|
||||
azureservicebus://DefaultAzureIdentity@SERVICE_BUSNAMESPACE
|
||||
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
|
||||
|
||||
Transport Options
|
||||
=================
|
||||
|
@ -67,6 +69,13 @@ from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
|
|||
ServiceBusSender)
|
||||
from azure.servicebus.management import ServiceBusAdministrationClient
|
||||
|
||||
try:
|
||||
from azure.identity import (DefaultAzureCredential,
|
||||
ManagedIdentityCredential)
|
||||
except ImportError:
|
||||
DefaultAzureCredential = None
|
||||
ManagedIdentityCredential = None
|
||||
|
||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||
from kombu.utils.json import dumps, loads
|
||||
from kombu.utils.objects import cached_property
|
||||
|
@ -129,8 +138,10 @@ class Channel(virtual.Channel):
|
|||
self.qos.restore_at_shutdown = False
|
||||
|
||||
def _try_parse_connection_string(self) -> None:
|
||||
self._namespace, self._policy, self._sas_key = Transport.parse_uri(
|
||||
self._namespace, self._credential = Transport.parse_uri(
|
||||
self.conninfo.hostname)
|
||||
if ":" in self._credential:
|
||||
self._policy, self._sas_key = self._credential.split(':', 1)
|
||||
|
||||
# Convert
|
||||
endpoint = 'sb://' + self._namespace
|
||||
|
@ -340,8 +351,17 @@ class Channel(virtual.Channel):
|
|||
|
||||
@cached_property
|
||||
def queue_service(self) -> ServiceBusClient:
|
||||
return ServiceBusClient.from_connection_string(
|
||||
self._connection_string,
|
||||
if self._connection_string:
|
||||
return ServiceBusClient.from_connection_string(
|
||||
self._connection_string,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
)
|
||||
|
||||
return ServiceBusClient(
|
||||
self._namespace,
|
||||
self._credential,
|
||||
retry_total=self.retry_total,
|
||||
retry_backoff_factor=self.retry_backoff_factor,
|
||||
retry_backoff_max=self.retry_backoff_max
|
||||
|
@ -349,8 +369,14 @@ class Channel(virtual.Channel):
|
|||
|
||||
@cached_property
|
||||
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
||||
return ServiceBusAdministrationClient.from_connection_string(
|
||||
self._connection_string)
|
||||
if self._connection_string:
|
||||
return ServiceBusAdministrationClient.from_connection_string(
|
||||
self._connection_string
|
||||
)
|
||||
|
||||
return ServiceBusAdministrationClient(
|
||||
self._namespace, self._credential
|
||||
)
|
||||
|
||||
@property
|
||||
def conninfo(self):
|
||||
|
@ -417,25 +443,47 @@ class Transport(virtual.Transport):
|
|||
# > 'rootpolicy:some/key@somenamespace'
|
||||
uri = uri.replace('azureservicebus://', '')
|
||||
# > 'rootpolicy:some/key', 'somenamespace'
|
||||
policykeypair, namespace = uri.rsplit('@', 1)
|
||||
# > 'rootpolicy', 'some/key'
|
||||
policy, sas_key = policykeypair.split(':', 1)
|
||||
credential, namespace = uri.rsplit('@', 1)
|
||||
|
||||
if "DefaultAzureCredential".lower() == credential.lower():
|
||||
if DefaultAzureCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'DefaultAzureCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = DefaultAzureCredential()
|
||||
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||
if ManagedIdentityCredential is None:
|
||||
raise ImportError('Azure Service Bus transport with a '
|
||||
'ManagedIdentityCredential requires the '
|
||||
'azure-identity library')
|
||||
credential = ManagedIdentityCredential()
|
||||
else:
|
||||
# > 'rootpolicy', 'some/key'
|
||||
policy, sas_key = credential.split(':', 1)
|
||||
credential = f"{policy}:{sas_key}"
|
||||
|
||||
# Validate ASB connection string
|
||||
if not all([namespace, policy, sas_key]):
|
||||
if not all([namespace, credential]):
|
||||
raise ValueError(
|
||||
'Need a URI like '
|
||||
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
|
||||
'or the azure Endpoint connection string'
|
||||
)
|
||||
|
||||
return namespace, policy, sas_key
|
||||
return namespace, credential
|
||||
|
||||
@classmethod
|
||||
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
|
||||
namespace, policy, sas_key = cls.parse_uri(uri)
|
||||
return 'azureservicebus://{}:{}@{}'.format(
|
||||
policy,
|
||||
sas_key if include_password else mask,
|
||||
namespace, credential = cls.parse_uri(uri)
|
||||
if ":" in credential:
|
||||
policy, sas_key = credential.split(':', 1)
|
||||
return 'azureservicebus://{}:{}@{}'.format(
|
||||
policy,
|
||||
sas_key if include_password else mask,
|
||||
namespace
|
||||
)
|
||||
|
||||
return 'azureservicebus://{}@{}'.format(
|
||||
credential,
|
||||
namespace
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue