mirror of https://github.com/celery/kombu.git
azure service bus: add managed identity support (#1641)
This commit is contained in:
parent
55370c7e4c
commit
6f8676630d
|
@ -30,11 +30,13 @@ Features
|
||||||
Connection String
|
Connection String
|
||||||
=================
|
=================
|
||||||
|
|
||||||
Connection string has the following format:
|
Connection string has the following formats:
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
|
azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE
|
||||||
|
azureservicebus://DefaultAzureIdentity@SERVICE_BUSNAMESPACE
|
||||||
|
azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE
|
||||||
|
|
||||||
Transport Options
|
Transport Options
|
||||||
=================
|
=================
|
||||||
|
@ -67,6 +69,13 @@ from azure.servicebus import (ServiceBusClient, ServiceBusMessage,
|
||||||
ServiceBusSender)
|
ServiceBusSender)
|
||||||
from azure.servicebus.management import ServiceBusAdministrationClient
|
from azure.servicebus.management import ServiceBusAdministrationClient
|
||||||
|
|
||||||
|
try:
|
||||||
|
from azure.identity import (DefaultAzureCredential,
|
||||||
|
ManagedIdentityCredential)
|
||||||
|
except ImportError:
|
||||||
|
DefaultAzureCredential = None
|
||||||
|
ManagedIdentityCredential = None
|
||||||
|
|
||||||
from kombu.utils.encoding import bytes_to_str, safe_str
|
from kombu.utils.encoding import bytes_to_str, safe_str
|
||||||
from kombu.utils.json import dumps, loads
|
from kombu.utils.json import dumps, loads
|
||||||
from kombu.utils.objects import cached_property
|
from kombu.utils.objects import cached_property
|
||||||
|
@ -129,8 +138,10 @@ class Channel(virtual.Channel):
|
||||||
self.qos.restore_at_shutdown = False
|
self.qos.restore_at_shutdown = False
|
||||||
|
|
||||||
def _try_parse_connection_string(self) -> None:
|
def _try_parse_connection_string(self) -> None:
|
||||||
self._namespace, self._policy, self._sas_key = Transport.parse_uri(
|
self._namespace, self._credential = Transport.parse_uri(
|
||||||
self.conninfo.hostname)
|
self.conninfo.hostname)
|
||||||
|
if ":" in self._credential:
|
||||||
|
self._policy, self._sas_key = self._credential.split(':', 1)
|
||||||
|
|
||||||
# Convert
|
# Convert
|
||||||
endpoint = 'sb://' + self._namespace
|
endpoint = 'sb://' + self._namespace
|
||||||
|
@ -340,6 +351,7 @@ class Channel(virtual.Channel):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def queue_service(self) -> ServiceBusClient:
|
def queue_service(self) -> ServiceBusClient:
|
||||||
|
if self._connection_string:
|
||||||
return ServiceBusClient.from_connection_string(
|
return ServiceBusClient.from_connection_string(
|
||||||
self._connection_string,
|
self._connection_string,
|
||||||
retry_total=self.retry_total,
|
retry_total=self.retry_total,
|
||||||
|
@ -347,10 +359,24 @@ class Channel(virtual.Channel):
|
||||||
retry_backoff_max=self.retry_backoff_max
|
retry_backoff_max=self.retry_backoff_max
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return ServiceBusClient(
|
||||||
|
self._namespace,
|
||||||
|
self._credential,
|
||||||
|
retry_total=self.retry_total,
|
||||||
|
retry_backoff_factor=self.retry_backoff_factor,
|
||||||
|
retry_backoff_max=self.retry_backoff_max
|
||||||
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
|
||||||
|
if self._connection_string:
|
||||||
return ServiceBusAdministrationClient.from_connection_string(
|
return ServiceBusAdministrationClient.from_connection_string(
|
||||||
self._connection_string)
|
self._connection_string
|
||||||
|
)
|
||||||
|
|
||||||
|
return ServiceBusAdministrationClient(
|
||||||
|
self._namespace, self._credential
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def conninfo(self):
|
def conninfo(self):
|
||||||
|
@ -417,25 +443,47 @@ class Transport(virtual.Transport):
|
||||||
# > 'rootpolicy:some/key@somenamespace'
|
# > 'rootpolicy:some/key@somenamespace'
|
||||||
uri = uri.replace('azureservicebus://', '')
|
uri = uri.replace('azureservicebus://', '')
|
||||||
# > 'rootpolicy:some/key', 'somenamespace'
|
# > 'rootpolicy:some/key', 'somenamespace'
|
||||||
policykeypair, namespace = uri.rsplit('@', 1)
|
credential, namespace = uri.rsplit('@', 1)
|
||||||
|
|
||||||
|
if "DefaultAzureCredential".lower() == credential.lower():
|
||||||
|
if DefaultAzureCredential is None:
|
||||||
|
raise ImportError('Azure Service Bus transport with a '
|
||||||
|
'DefaultAzureCredential requires the '
|
||||||
|
'azure-identity library')
|
||||||
|
credential = DefaultAzureCredential()
|
||||||
|
elif "ManagedIdentityCredential".lower() == credential.lower():
|
||||||
|
if ManagedIdentityCredential is None:
|
||||||
|
raise ImportError('Azure Service Bus transport with a '
|
||||||
|
'ManagedIdentityCredential requires the '
|
||||||
|
'azure-identity library')
|
||||||
|
credential = ManagedIdentityCredential()
|
||||||
|
else:
|
||||||
# > 'rootpolicy', 'some/key'
|
# > 'rootpolicy', 'some/key'
|
||||||
policy, sas_key = policykeypair.split(':', 1)
|
policy, sas_key = credential.split(':', 1)
|
||||||
|
credential = f"{policy}:{sas_key}"
|
||||||
|
|
||||||
# Validate ASB connection string
|
# Validate ASB connection string
|
||||||
if not all([namespace, policy, sas_key]):
|
if not all([namespace, credential]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Need a URI like '
|
'Need a URI like '
|
||||||
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
|
'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa
|
||||||
'or the azure Endpoint connection string'
|
'or the azure Endpoint connection string'
|
||||||
)
|
)
|
||||||
|
|
||||||
return namespace, policy, sas_key
|
return namespace, credential
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
|
def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
|
||||||
namespace, policy, sas_key = cls.parse_uri(uri)
|
namespace, credential = cls.parse_uri(uri)
|
||||||
|
if ":" in credential:
|
||||||
|
policy, sas_key = credential.split(':', 1)
|
||||||
return 'azureservicebus://{}:{}@{}'.format(
|
return 'azureservicebus://{}:{}@{}'.format(
|
||||||
policy,
|
policy,
|
||||||
sas_key if include_password else mask,
|
sas_key if include_password else mask,
|
||||||
namespace
|
namespace
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return 'azureservicebus://{}@{}'.format(
|
||||||
|
credential,
|
||||||
|
namespace
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue