Add an option to not base64-encode SQS messages.

Also simplify the base64 decoding logic so that we don't have to
run base64 decoding twice for every message.
This commit is contained in:
Shane Hathaway 2022-03-02 22:23:31 -07:00 committed by Asif Saif Uddin
parent 22adaaa38b
commit 9b505f4297
2 changed files with 22 additions and 18 deletions

View File

@ -395,8 +395,11 @@ class Channel(virtual.Channel):
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q_url = self._new_queue(queue)
kwargs = {'QueueUrl': q_url,
'MessageBody': AsyncMessage().encode(dumps(message))}
if self.sqs_base64_encoding:
body = AsyncMessage().encode(dumps(message))
else:
body = dumps(message)
kwargs = {'QueueUrl': q_url, 'MessageBody': body}
if queue.endswith('.fifo'):
if 'MessageGroupId' in message['properties']:
kwargs['MessageGroupId'] = \
@ -420,22 +423,19 @@ class Channel(virtual.Channel):
c.send_message(**kwargs)
@staticmethod
def __b64_encoded(byte_string):
def _optional_b64_decode(byte_string):
try:
return base64.b64encode(
base64.b64decode(byte_string)
) == byte_string
data = base64.b64decode(byte_string)
if base64.b64encode(data) == byte_string:
return data
# else the base64 module found some embedded base64 content
# that should be ignored.
except Exception: # pylint: disable=broad-except
return False
pass
return byte_string
def _message_to_python(self, message, queue_name, queue):
body = message['Body'].encode()
try:
if self.__b64_encoded(body):
body = base64.b64decode(body)
except TypeError:
pass
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
queue = self._new_queue(queue_name)
@ -837,6 +837,10 @@ class Channel(virtual.Channel):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
def sqs_base64_encoding(self):
return self.transport_options.get('sqs_base64_encoding', True)
class Transport(virtual.Transport):
"""SQS Transport.

View File

@ -336,13 +336,13 @@ class test_Channel:
with pytest.raises(Empty):
self.channel._get_bulk(self.queue_name)
def test_is_base64_encoded(self):
def test_optional_b64_decode(self):
raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \
b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa
b64_enc = base64.b64encode(raw)
assert self.channel._Channel__b64_encoded(b64_enc)
assert not self.channel._Channel__b64_encoded(raw)
assert not self.channel._Channel__b64_encoded(b"test123")
assert self.channel._optional_b64_decode(b64_enc) == raw
assert self.channel._optional_b64_decode(raw) == raw
assert self.channel._optional_b64_decode(b"test123") == b"test123"
def test_messages_to_python(self):
from kombu.asynchronous.aws.sqs.message import Message