diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py index 01b12965..5aa7c03b 100644 --- a/kombu/transport/azureservicebus.py +++ b/kombu/transport/azureservicebus.py @@ -30,11 +30,13 @@ Features Connection String ================= -Connection string has the following format: +Connection string has the following formats: .. code-block:: azureservicebus://SAS_POLICY_NAME:SAS_KEY@SERVICE_BUSNAMESPACE + azureservicebus://DefaultAzureIdentity@SERVICE_BUSNAMESPACE + azureservicebus://ManagedIdentityCredential@SERVICE_BUSNAMESPACE Transport Options ================= @@ -67,6 +69,13 @@ from azure.servicebus import (ServiceBusClient, ServiceBusMessage, ServiceBusSender) from azure.servicebus.management import ServiceBusAdministrationClient +try: + from azure.identity import (DefaultAzureCredential, + ManagedIdentityCredential) +except ImportError: + DefaultAzureCredential = None + ManagedIdentityCredential = None + from kombu.utils.encoding import bytes_to_str, safe_str from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property @@ -129,8 +138,10 @@ class Channel(virtual.Channel): self.qos.restore_at_shutdown = False def _try_parse_connection_string(self) -> None: - self._namespace, self._policy, self._sas_key = Transport.parse_uri( + self._namespace, self._credential = Transport.parse_uri( self.conninfo.hostname) + if ":" in self._credential: + self._policy, self._sas_key = self._credential.split(':', 1) # Convert endpoint = 'sb://' + self._namespace @@ -340,8 +351,17 @@ class Channel(virtual.Channel): @cached_property def queue_service(self) -> ServiceBusClient: - return ServiceBusClient.from_connection_string( - self._connection_string, + if self._connection_string: + return ServiceBusClient.from_connection_string( + self._connection_string, + retry_total=self.retry_total, + retry_backoff_factor=self.retry_backoff_factor, + retry_backoff_max=self.retry_backoff_max + ) + + return ServiceBusClient( + self._namespace, + self._credential, retry_total=self.retry_total, retry_backoff_factor=self.retry_backoff_factor, retry_backoff_max=self.retry_backoff_max @@ -349,8 +369,14 @@ class Channel(virtual.Channel): @cached_property def queue_mgmt_service(self) -> ServiceBusAdministrationClient: - return ServiceBusAdministrationClient.from_connection_string( - self._connection_string) + if self._connection_string: + return ServiceBusAdministrationClient.from_connection_string( + self._connection_string + ) + + return ServiceBusAdministrationClient( + self._namespace, self._credential + ) @property def conninfo(self): @@ -417,25 +443,47 @@ class Transport(virtual.Transport): # > 'rootpolicy:some/key@somenamespace' uri = uri.replace('azureservicebus://', '') # > 'rootpolicy:some/key', 'somenamespace' - policykeypair, namespace = uri.rsplit('@', 1) - # > 'rootpolicy', 'some/key' - policy, sas_key = policykeypair.split(':', 1) + credential, namespace = uri.rsplit('@', 1) + + if "DefaultAzureCredential".lower() == credential.lower(): + if DefaultAzureCredential is None: + raise ImportError('Azure Service Bus transport with a ' + 'DefaultAzureCredential requires the ' + 'azure-identity library') + credential = DefaultAzureCredential() + elif "ManagedIdentityCredential".lower() == credential.lower(): + if ManagedIdentityCredential is None: + raise ImportError('Azure Service Bus transport with a ' + 'ManagedIdentityCredential requires the ' + 'azure-identity library') + credential = ManagedIdentityCredential() + else: + # > 'rootpolicy', 'some/key' + policy, sas_key = credential.split(':', 1) + credential = f"{policy}:{sas_key}" # Validate ASB connection string - if not all([namespace, policy, sas_key]): + if not all([namespace, credential]): raise ValueError( 'Need a URI like ' 'azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace} ' # noqa 'or the azure Endpoint connection string' ) - return namespace, policy, sas_key + return namespace, credential @classmethod def as_uri(cls, uri: str, include_password=False, mask='**') -> str: - namespace, policy, sas_key = cls.parse_uri(uri) - return 'azureservicebus://{}:{}@{}'.format( - policy, - sas_key if include_password else mask, + namespace, credential = cls.parse_uri(uri) + if ":" in credential: + policy, sas_key = credential.split(':', 1) + return 'azureservicebus://{}:{}@{}'.format( + policy, + sas_key if include_password else mask, + namespace + ) + + return 'azureservicebus://{}@{}'.format( + credential, namespace )