diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 0c8d1ee4..7d3fa0cc 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -766,34 +766,30 @@ class Channel(virtual.Channel): return c def _handle_sts_session(self, queue, q): + region = q.get('region', self.region) if not hasattr(self, 'sts_expiration'): # STS token - token init - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c + return self._new_predefined_queue_client_with_sts_session(queue, region) # STS token - refresh if expired elif self.sts_expiration.replace(tzinfo=None) < datetime.utcnow(): - sts_creds = self.generate_sts_session_token( - self.transport_options.get('sts_role_arn'), - self.transport_options.get('sts_token_timeout', 900)) - self.sts_expiration = sts_creds['Expiration'] - c = self._predefined_queue_clients[queue] = self.new_sqs_client( - region=q.get('region', self.region), - access_key_id=sts_creds['AccessKeyId'], - secret_access_key=sts_creds['SecretAccessKey'], - session_token=sts_creds['SessionToken'], - ) - return c + return self._new_predefined_queue_client_with_sts_session(queue, region) else: # STS token - ruse existing + if queue not in self._predefined_queue_clients: + return self._new_predefined_queue_client_with_sts_session(queue, region) return self._predefined_queue_clients[queue] + def _new_predefined_queue_client_with_sts_session(self, queue, region): + sts_creds = self.generate_sts_session_token( + self.transport_options.get('sts_role_arn'), + self.transport_options.get('sts_token_timeout', 900)) + self.sts_expiration = sts_creds['Expiration'] + c = self._predefined_queue_clients[queue] = self.new_sqs_client( + region=region, + access_key_id=sts_creds['AccessKeyId'], + secret_access_key=sts_creds['SecretAccessKey'], + session_token=sts_creds['SessionToken'], + ) + return c + def generate_sts_session_token(self, role_arn, token_expiry_seconds): sts_client = boto3.client('sts') sts_policy = sts_client.assume_role( diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index b82be5aa..b6a1d6ae 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -996,6 +996,34 @@ class test_Channel: # Assert mock_generate_sts_session_token.assert_not_called() + def test_sts_session_with_multiple_predefined_queues(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + 'sts_role_arn': 'test::arn' + }) + channel = connection.channel() + sqs = SQS_Channel_sqs.__get__(channel, SQS.Channel) + + mock_generate_sts_session_token = Mock() + mock_new_sqs_client = Mock() + channel.new_sqs_client = mock_new_sqs_client + mock_generate_sts_session_token.return_value = { + 'Expiration': datetime.utcnow() + timedelta(days=1), + 'SessionToken': 123, + 'AccessKeyId': 123, + 'SecretAccessKey': 123 + } + + channel.generate_sts_session_token = mock_generate_sts_session_token + + # Act + sqs(queue='queue-1') + sqs(queue='queue-2') + + # Assert + mock_generate_sts_session_token.assert_called() + mock_new_sqs_client.assert_called() + def test_message_attribute(self): message = 'my test message' self.producer.publish(message, message_attributes={