diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 6e9b8dfa..7e8c44a6 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -395,8 +395,11 @@ class Channel(virtual.Channel): def _put(self, queue, message, **kwargs): """Put message onto queue.""" q_url = self._new_queue(queue) - kwargs = {'QueueUrl': q_url, - 'MessageBody': AsyncMessage().encode(dumps(message))} + if self.sqs_base64_encoding: + body = AsyncMessage().encode(dumps(message)) + else: + body = dumps(message) + kwargs = {'QueueUrl': q_url, 'MessageBody': body} if queue.endswith('.fifo'): if 'MessageGroupId' in message['properties']: kwargs['MessageGroupId'] = \ @@ -420,22 +423,19 @@ class Channel(virtual.Channel): c.send_message(**kwargs) @staticmethod - def __b64_encoded(byte_string): + def _optional_b64_decode(byte_string): try: - return base64.b64encode( - base64.b64decode(byte_string) - ) == byte_string + data = base64.b64decode(byte_string) + if base64.b64encode(data) == byte_string: + return data + # else the base64 module found some embedded base64 content + # that should be ignored. except Exception: # pylint: disable=broad-except - return False + pass + return byte_string def _message_to_python(self, message, queue_name, queue): - body = message['Body'].encode() - try: - if self.__b64_encoded(body): - body = base64.b64decode(body) - except TypeError: - pass - + body = self._optional_b64_decode(message['Body'].encode()) payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: queue = self._new_queue(queue_name) @@ -837,6 +837,10 @@ class Channel(virtual.Channel): return self.transport_options.get('wait_time_seconds', self.default_wait_time_seconds) + @cached_property + def sqs_base64_encoding(self): + return self.transport_options.get('sqs_base64_encoding', True) + class Transport(virtual.Transport): """SQS Transport. diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 6056dd3d..ea261659 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -336,13 +336,13 @@ class test_Channel: with pytest.raises(Empty): self.channel._get_bulk(self.queue_name) - def test_is_base64_encoded(self): + def test_optional_b64_decode(self): raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \ b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa b64_enc = base64.b64encode(raw) - assert self.channel._Channel__b64_encoded(b64_enc) - assert not self.channel._Channel__b64_encoded(raw) - assert not self.channel._Channel__b64_encoded(b"test123") + assert self.channel._optional_b64_decode(b64_enc) == raw + assert self.channel._optional_b64_decode(raw) == raw + assert self.channel._optional_b64_decode(b"test123") == b"test123" def test_messages_to_python(self): from kombu.asynchronous.aws.sqs.message import Message