fix(sqs): don't crash on multiple predefined queues with aws sts session (#2224)

* chore(sqs): write the test case for multiple predefined queues with aws sts session

* fix(sqs): don't crash on multiple predefined queues with aws sts session

* refactor(sqs): make _new_predefined_queue_client_with_sts_session()
This commit is contained in:
Manjong Han 2025-01-14 00:49:43 +09:00 committed by GitHub
parent 4c64cdd39f
commit 83b296f011
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 22 deletions

View File

@ -766,34 +766,30 @@ class Channel(virtual.Channel):
return c
def _handle_sts_session(self, queue, q):
region = q.get('region', self.region)
if not hasattr(self, 'sts_expiration'): # STS token - token init
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
return self._new_predefined_queue_client_with_sts_session(queue, region)
# STS token - refresh if expired
elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow():
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=q.get('region', self.region),
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
return self._new_predefined_queue_client_with_sts_session(queue, region)
else: # STS token - ruse existing
if queue not in self._predefined_queue_clients:
return self._new_predefined_queue_client_with_sts_session(queue, region)
return self._predefined_queue_clients[queue]
def _new_predefined_queue_client_with_sts_session(self, queue, region):
sts_creds = self.generate_sts_session_token(
self.transport_options.get('sts_role_arn'),
self.transport_options.get('sts_token_timeout', 900))
self.sts_expiration = sts_creds['Expiration']
c = self._predefined_queue_clients[queue] = self.new_sqs_client(
region=region,
access_key_id=sts_creds['AccessKeyId'],
secret_access_key=sts_creds['SecretAccessKey'],
session_token=sts_creds['SessionToken'],
)
return c
def generate_sts_session_token(self, role_arn, token_expiry_seconds):
sts_client = boto3.client('sts')
sts_policy = sts_client.assume_role(

View File

@ -996,6 +996,34 @@ class test_Channel:
# Assert
mock_generate_sts_session_token.assert_not_called()
def test_sts_session_with_multiple_predefined_queues(self):
connection = Connection(transport=SQS.Transport, transport_options={
'predefined_queues': example_predefined_queues,
'sts_role_arn': 'test::arn'
})
channel = connection.channel()
sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel)
mock_generate_sts_session_token = Mock()
mock_new_sqs_client = Mock()
channel.new_sqs_client = mock_new_sqs_client
mock_generate_sts_session_token.return_value = {
'Expiration': datetime.utcnow() + timedelta(days=1),
'SessionToken': 123,
'AccessKeyId': 123,
'SecretAccessKey': 123
}
channel.generate_sts_session_token = mock_generate_sts_session_token
# Act
sqs(queue='queue-1')
sqs(queue='queue-2')
# Assert
mock_generate_sts_session_token.assert_called()
mock_new_sqs_client.assert_called()
def test_message_attribute(self):
message = 'my test message'
self.producer.publish(message, message_attributes={