mirror of https://github.com/celery/kombu.git
fix(sqs): don't crash on multiple predefined queues with aws sts session (#2224)
* chore(sqs): write the test case for multiple predefined queues with aws sts session * fix(sqs): don't crash on multiple predefined queues with aws sts session * refactor(sqs): make _new_predefined_queue_client_with_sts_session()
This commit is contained in:
parent
4c64cdd39f
commit
83b296f011
|
@ -766,34 +766,30 @@ class Channel(virtual.Channel):
|
|||
return c
|
||||
|
||||
def _handle_sts_session(self, queue, q):
|
||||
region = q.get('region', self.region)
|
||||
if not hasattr(self, 'sts_expiration'): # STS token - token init
|
||||
sts_creds = self.generate_sts_session_token(
|
||||
self.transport_options.get('sts_role_arn'),
|
||||
self.transport_options.get('sts_token_timeout', 900))
|
||||
self.sts_expiration = sts_creds['Expiration']
|
||||
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
|
||||
region=q.get('region', self.region),
|
||||
access_key_id=sts_creds['AccessKeyId'],
|
||||
secret_access_key=sts_creds['SecretAccessKey'],
|
||||
session_token=sts_creds['SessionToken'],
|
||||
)
|
||||
return c
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
# STS token - refresh if expired
|
||||
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
|
||||
sts_creds = self.generate_sts_session_token(
|
||||
self.transport_options.get('sts_role_arn'),
|
||||
self.transport_options.get('sts_token_timeout', 900))
|
||||
self.sts_expiration = sts_creds['Expiration']
|
||||
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
|
||||
region=q.get('region', self.region),
|
||||
access_key_id=sts_creds['AccessKeyId'],
|
||||
secret_access_key=sts_creds['SecretAccessKey'],
|
||||
session_token=sts_creds['SessionToken'],
|
||||
)
|
||||
return c
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
else: # STS token - ruse existing
|
||||
if queue not in self._predefined_queue_clients:
|
||||
return self._new_predefined_queue_client_with_sts_session(queue, region)
|
||||
return self._predefined_queue_clients[queue]
|
||||
|
||||
def _new_predefined_queue_client_with_sts_session(self, queue, region):
|
||||
sts_creds = self.generate_sts_session_token(
|
||||
self.transport_options.get('sts_role_arn'),
|
||||
self.transport_options.get('sts_token_timeout', 900))
|
||||
self.sts_expiration = sts_creds['Expiration']
|
||||
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
|
||||
region=region,
|
||||
access_key_id=sts_creds['AccessKeyId'],
|
||||
secret_access_key=sts_creds['SecretAccessKey'],
|
||||
session_token=sts_creds['SessionToken'],
|
||||
)
|
||||
return c
|
||||
|
||||
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
|
||||
sts_client = boto3.client('sts')
|
||||
sts_policy = sts_client.assume_role(
|
||||
|
|
|
@ -996,6 +996,34 @@ class test_Channel:
|
|||
# Assert
|
||||
mock_generate_sts_session_token.assert_not_called()
|
||||
|
||||
def test_sts_session_with_multiple_predefined_queues(self):
|
||||
connection = Connection(transport=SQS.Transport, transport_options={
|
||||
'predefined_queues': example_predefined_queues,
|
||||
'sts_role_arn': 'test::arn'
|
||||
})
|
||||
channel = connection.channel()
|
||||
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
|
||||
|
||||
mock_generate_sts_session_token = Mock()
|
||||
mock_new_sqs_client = Mock()
|
||||
channel.new_sqs_client = mock_new_sqs_client
|
||||
mock_generate_sts_session_token.return_value = {
|
||||
'Expiration': datetime.utcnow() + timedelta(days=1),
|
||||
'SessionToken': 123,
|
||||
'AccessKeyId': 123,
|
||||
'SecretAccessKey': 123
|
||||
}
|
||||
|
||||
channel.generate_sts_session_token = mock_generate_sts_session_token
|
||||
|
||||
# Act
|
||||
sqs(queue='queue-1')
|
||||
sqs(queue='queue-2')
|
||||
|
||||
# Assert
|
||||
mock_generate_sts_session_token.assert_called()
|
||||
mock_new_sqs_client.assert_called()
|
||||
|
||||
def test_message_attribute(self):
|
||||
message = 'my test message'
|
||||
self.producer.publish(message, message_attributes={
|
||||
|
|
Loading…
Reference in New Issue